mod listener;
mod stream;
pub use {listener::*, stream::*};
/// `listener` and `stream` submodules that are compiled only when the `tokio` Cargo feature
/// is enabled; all of their public items are re-exported from this module.
#[cfg(feature = "tokio")]
pub mod tokio {
mod listener;
mod stream;
pub use {listener::*, stream::*};
}
use {
crate::{
assume_nonzero_slice, check_nonzero_slice,
local_socket::{ListenerOptions, Name, NameInner},
os::unix::{
c_wrappers,
ud_addr::{name_too_long, TerminatedUdAddr, UdAddr, SUN_LEN},
unixprelude::*,
},
timeout_expiry,
},
std::{
ffi::{CStr, OsStr},
fmt::{self, Debug, Formatter},
io,
mem::MaybeUninit,
num::NonZeroU8,
path::Path,
time::{Duration, Instant},
},
};
/// Error message text for a connection attempt to a local socket server that timed out.
// NOTE(review): not referenced in the code visible here; presumably used by the `stream`
// submodules — confirm.
const CONN_TIMEOUT_MSG: &str = "timed out while connecting to local socket server";
/// Holds the NUL-terminated filesystem path of a socket and unlinks that path when dropped.
///
/// An empty buffer means the guard is disarmed and its destructor does nothing; `Default`
/// produces such a disarmed guard.
// NOTE(review): the derived `Clone` duplicates the path, so every clone will also try to
// unlink it on drop — confirm this is intended wherever a guard is cloned.
#[derive(Clone, Default)]
struct ReclaimGuard(Box<[u8]>);
impl ReclaimGuard {
    /// A guard that owns no path and therefore unlinks nothing when dropped.
    fn disarmed() -> Self {
        Self(Box::default())
    }

    /// Builds a guard for the socket file at `addr`, or a disarmed one when reclaiming is
    /// not requested (`cond == false`) or there is no filesystem entry to remove.
    ///
    /// Empty paths and, on Linux/Android, paths starting with a NUL byte (abstract
    /// namespace names) are never armed.
    fn new(cond: bool, addr: TerminatedUdAddr<'_>) -> Self {
        if !cond {
            return Self::disarmed();
        }
        let raw = addr.inner().path();
        if raw.is_empty() {
            return Self::disarmed();
        }
        if cfg!(any(target_os = "linux", target_os = "android")) && raw.first() == Some(&0) {
            return Self::disarmed();
        }
        let owned = addr.path().to_owned();
        Self(owned.into_bytes_with_nul().into_boxed_slice())
    }

    /// Moves the armed path out into a new guard, leaving `self` disarmed.
    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
    fn take(&mut self) -> Self {
        Self(std::mem::replace(&mut self.0, Box::default()))
    }

    /// Disarms the guard without unlinking anything.
    fn forget(&mut self) {
        self.0 = Box::default();
    }

    /// The armed path, or `None` when the guard is disarmed.
    fn as_c_str(&self) -> Option<&CStr> {
        if self.0.is_empty() {
            return None;
        }
        // SAFETY: a non-empty buffer only ever originates from
        // `CString::into_bytes_with_nul` in `new` (possibly moved by `take` or cloned), so
        // it ends with exactly one NUL and contains no interior NULs.
        Some(unsafe { CStr::from_bytes_with_nul_unchecked(&self.0) })
    }
}
impl Drop for ReclaimGuard {
    /// Best-effort removal of the armed socket path; disarmed guards do nothing.
    fn drop(&mut self) {
        let Some(path) = self.as_c_str() else { return };
        // Deliberately ignored: a destructor has no way to report the failure.
        let _ = c_wrappers::unlink(path);
    }
}
impl Debug for ReclaimGuard {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let s = self.as_c_str().map(|s| OsStr::from_bytes(s.to_bytes()));
f.debug_tuple("ReclaimGuard").field(&s).finish()
}
}
/// Runs `listen` on the address derived from `opts.name`, optionally replacing an existing
/// socket file at that address.
///
/// `dispatch_name` is called with `create_dirs = true`, so missing parent directories of
/// pseudo-namespace paths may be created. When `listen` fails with `AddrInUse` and the
/// options request overwriting, the existing path is unlinked (a missing file is not an
/// error) and `listen` is retried. One unlink-and-retry is always attempted; further ones
/// happen only while the configured maximum spin time has not elapsed. Every other outcome
/// of `listen` is returned unchanged.
fn listen_and_maybe_overwrite<T>(
    mut opts: ListenerOptions<'_>,
    mut listen: impl FnMut(TerminatedUdAddr<'_>, &mut ListenerOptions<'_>) -> io::Result<T>,
) -> io::Result<T> {
    let deadline = opts.get_max_spin_time().map(timeout_expiry).transpose()?;
    dispatch_name(
        &mut opts,
        true,
        |o| o.name.borrow(),
        |o| o.get_max_spin_time_mut(),
        |addr, o| {
            let mut retried = false;
            loop {
                let err = match listen(addr, o) {
                    Err(e) if keep_trying_to_overwrite(&e, o) => e,
                    done => return done,
                };
                // Evaluated unconditionally so the remaining spin time is written back.
                let may_spin = continue_spin_loop(deadline, o.get_max_spin_time_mut());
                if retried && !may_spin {
                    return Err(err);
                }
                retried = true;
                unlink_and_eat_noents(addr)?;
            }
        },
    )
}
/// Unlinks the filesystem path of `addr`, treating an already-missing file as success.
fn unlink_and_eat_noents(addr: TerminatedUdAddr<'_>) -> io::Result<()> {
    c_wrappers::unlink(addr.path()).or_else(|e| match e.kind() {
        io::ErrorKind::NotFound => Ok(()),
        _ => Err(e),
    })
}
/// Whether a failed bind should be answered by unlinking the existing path and retrying.
///
/// This requires that the options enable overwriting and that the error is `AddrInUse`.
fn keep_trying_to_overwrite(e: &io::Error, options: &ListenerOptions<'_>) -> bool {
    if !options.get_try_overwrite() {
        return false;
    }
    matches!(e.kind(), io::ErrorKind::AddrInUse)
}
/// Converts `s` with `check_nonzero_slice`.
///
/// A `None` result is turned into an `InvalidInput` error explaining that interior NUL
/// bytes are not allowed in socket names.
fn check_no_nul(s: &[u8]) -> io::Result<&[NonZeroU8]> {
    match check_nonzero_slice(s) {
        Some(nonzero) => Ok(nonzero),
        None => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "interior nul bytes are not allowed inside Unix domain socket names",
        )),
    }
}
/// Resolves the local socket name held by `o` into a Unix domain socket address and hands
/// it to `create`, returning whatever `create` ultimately returns.
///
/// - `get_name` extracts the name from `o`. The returned `Name` borrows `o`, so it is
///   called a second time in the fallback path instead of keeping the first borrow alive.
/// - `create_dirs` is forwarded to `with_missing_dir_creat` (pseudo-namespace names only).
/// - `max_spin_time` gives the retry helpers mutable access to the spin-time budget in `o`.
///
/// Behavior by name kind:
/// - Filesystem paths are used as-is; interior NUL bytes are rejected.
/// - Pseudo-namespace names are first placed under `/run/user/<uid>/`. If that attempt
///   fails with an error `fail_is_benign` accepts, the name is retried under `tmpdir()`,
///   and the first error is discarded.
/// - On Linux/Android, namespaced names are written with `init_namespaced`; interior NULs
///   are rejected.
fn dispatch_name<O, T>(
o: &mut O,
create_dirs: bool,
mut get_name: impl FnMut(&mut O) -> Name<'_>,
mut max_spin_time: impl FnMut(&mut O) -> Option<&mut Duration>,
mut create: impl FnMut(TerminatedUdAddr<'_>, &mut O) -> io::Result<T>,
) -> io::Result<T> {
let mut addr = UdAddr::new();
match get_name(o).0 {
NameInner::UdSocketPath(path) => {
addr.init(check_no_nul(path.as_bytes())?)?;
create(addr.write_terminator(), o)
}
NameInner::UdSocketPseudoNs(name) => {
let name = name.as_bytes();
// Preferred location: /run/user/<uid>/<name>.
write_run_user(&mut addr, name)?;
match with_missing_dir_creat(
o,
create_dirs,
addr.write_terminator(),
&mut max_spin_time,
&mut create,
) {
Err(e) if fail_is_benign(&e) => {
// Re-fetch the name: the earlier borrow of `o` ended before `o` was lent to
// `with_missing_dir_creat`. Assumes `get_name` returns the same variant on
// every call — hence `unreachable!()`.
let NameInner::UdSocketPseudoNs(name) = get_name(o).0 else { unreachable!() };
// Fallback location under the temporary directory.
write_prefixed(&mut addr, tmpdir(), name.as_bytes())?;
with_missing_dir_creat(
o,
create_dirs,
addr.write_terminator(),
&mut max_spin_time,
&mut create,
)
}
otherwise => otherwise,
}
}
#[cfg(any(target_os = "linux", target_os = "android"))]
NameInner::UdSocketNs(name) => {
addr.init_namespaced(check_no_nul(&name)?)?;
create(addr.write_terminator(), o)
}
}
}
fn with_missing_dir_creat<O, T>(
options: &mut O,
create: bool,
addr: TerminatedUdAddr<'_>,
mut max_spin_time: impl FnMut(&mut O) -> Option<&mut Duration>,
mut f: impl FnMut(TerminatedUdAddr<'_>, &mut O) -> io::Result<T>,
) -> io::Result<T> {
let end = max_spin_time(options).copied().map(timeout_expiry).transpose()?;
let mut first = true;
loop {
let err = match f(addr, options) {
Err(e) if create && fail_is_benign(&e) => e,
otherwise => return otherwise,
};
if !continue_spin_loop(end, max_spin_time(options)) && !first {
break Err(err);
}
first = false;
create_missing_dirs(addr).then_some(()).ok_or(err)?;
}
}
fn create_missing_dirs(addr: TerminatedUdAddr<'_>) -> bool {
let path = Path::new(OsStr::from_bytes(addr.inner().path()));
let Some(dir) = path.parent() else { return false };
let false = dir.as_os_str().is_empty() else { return false };
match std::fs::create_dir_all(dir) {
Ok(()) => true,
Err(e) if e.kind() == io::ErrorKind::AlreadyExists => true,
Err(..) => false,
}
}
/// Decides whether a retry loop may run again before `end`, and records the time left.
///
/// Returns `false` when there is no deadline or the deadline has been reached; otherwise
/// returns `true`.
///
/// When `end` is `Some` and `spin_time` is provided, `spin_time` is overwritten with the
/// time remaining until `end` (`Duration::ZERO` once the deadline has passed). When `end`
/// is `None`, `spin_time` is left untouched.
fn continue_spin_loop(end: Option<Instant>, spin_time: Option<&mut Duration>) -> bool {
    let Some(end) = end else { return false };
    // Zero once `end` has been reached, strictly positive before then.
    let remaining = end.saturating_duration_since(Instant::now());
    if let Some(time) = spin_time {
        *time = remaining;
    }
    !remaining.is_zero()
}
/// Length of the longest possible `/run/user/<uid>/` prefix: the fixed text plus the
/// maximum number of decimal digits in a `uid_t`.
const MAX_RUN_USER: usize = "/run/user//".len() + uid_t::MAX.ilog10() as usize + 1;
/// One byte more than `MAX_RUN_USER`.
// NOTE(review): presumably room for a NUL terminator; not used in the code visible here —
// confirm where it is referenced.
const RUN_USER_BUF: usize = MAX_RUN_USER + 1;
/// Longest name `write_run_user` accepts, so that even the longest `/run/user/<uid>/`
/// prefix plus the name stays within `SUN_LEN` bytes.
const NMCAP: usize = SUN_LEN - MAX_RUN_USER;
/// Writes `/run/user/<uid>/<name>` into `addr`, replacing any NUL bytes in `name` with `_`.
///
/// `<uid>` is the decimal real user ID returned by `getuid`.
///
/// # Errors
/// Returns `name_too_long()` if `name` is longer than `NMCAP`.
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
fn write_run_user(addr: &mut UdAddr, name: &[u8]) -> io::Result<()> {
if name.len() > NMCAP {
return Err(name_too_long());
}
addr.reset_len();
// The literal contains no NUL bytes.
let start = unsafe { assume_nonzero_slice(b"/run/user/") };
unsafe { addr.push_slice(start) };
// Write the UID straight into the buffer after the prefix: digits are produced least
// significant first, followed by '/', then the digits are reversed into normal order.
let uid_len = {
let buf = unsafe { addr.path_buf_mut() };
let mut idx = start.len();
let mut uid = unsafe { libc::getuid() };
loop {
buf[idx] = MaybeUninit::new((uid % 10) as u8 + b'0');
uid /= 10;
idx += 1;
if uid == 0 {
break;
}
}
buf[idx] = MaybeUninit::new(b'/');
buf[start.len()..idx].reverse();
// Bytes written after the prefix: the digits plus the trailing '/'.
idx + 1 - start.len()
};
// Extend the address length over the hand-written UID bytes.
unsafe { addr.incr_len(uid_len) };
let esc_start = addr.len();
unsafe { addr.push_slice_with_nuls(name) };
// Only the name portion may contain NULs; replace them with '_'.
escape_nuls(&mut addr.path_mut()[esc_start..]);
Ok(())
}
#[allow(clippy::arithmetic_side_effects, clippy::indexing_slicing)]
fn write_prefixed(addr: &mut UdAddr, pfx: &[NonZeroU8], name: &[u8]) -> io::Result<()> {
if pfx.len() + name.len() > SUN_LEN {
return Err(name_too_long());
}
let name = check_no_nul(name)?;
addr.reset_len();
unsafe { addr.push_slice(pfx) };
unsafe { addr.push_slice(name) };
escape_nuls(&mut addr.path_mut()[pfx.len()..]);
Ok(())
}
/// Replaces every NUL byte in `b` with `_`, in place.
fn escape_nuls(b: &mut [u8]) {
    for byte in b {
        if *byte == 0 {
            *byte = b'_';
        }
    }
}
/// Classifies errors after which `dispatch_name` and `with_missing_dir_creat` keep going
/// (by creating directories or trying the fallback location).
///
/// These are: not found, unsupported, or the raw OS error `ENOTDIR`.
fn fail_is_benign(e: &io::Error) -> bool {
    match e.kind() {
        io::ErrorKind::NotFound | io::ErrorKind::Unsupported => true,
        _ => e.raw_os_error() == Some(libc::ENOTDIR),
    }
}
/// Returns the directory prefix used for pseudo-namespace sockets when `/run/user/<uid>/`
/// cannot be used.
///
/// On Android, the first of these that qualifies is used:
/// - the value of `$TMPDIR`;
/// - the value of `$TEMPDIR`;
/// - `/data/local/tmp/`.
///
/// An environment value qualifies only if it is an absolute path (starts with `/`). This
/// prevents an empty or relative setting from placing sockets relative to the current
/// working directory.
///
/// On other platforms the result is always `/tmp/`. Environment values may lack a trailing
/// `/`.
// NOTE(review): the returned slice borrows the process environment (via `getenv`) under an
// unbounded lifetime; a later `setenv`/`putenv` could invalidate it. Callers should use it
// immediately — consider copying it into owned storage.
fn tmpdir<'p>() -> &'p [NonZeroU8] {
    if !cfg!(target_os = "android") {
        // SAFETY: the literal contains no NUL bytes.
        return unsafe { assume_nonzero_slice(b"/tmp/") };
    }
    let vars: [&[u8]; 2] = [b"TMPDIR\0", b"TEMPDIR\0"];
    for var in vars {
        // SAFETY: `var` is a NUL-terminated byte string.
        let ptr = unsafe { libc::getenv(var.as_ptr().cast()) };
        if ptr.is_null() {
            continue;
        }
        // SAFETY: `getenv` returned a non-null pointer to a NUL-terminated string, so
        // `strlen` stays in bounds and the first `len` bytes are all non-zero.
        let val: &'p [NonZeroU8] = unsafe {
            let len = libc::strlen(ptr);
            std::slice::from_raw_parts(ptr.cast(), len)
        };
        if matches!(val.first(), Some(c) if c.get() == b'/') {
            return val;
        }
    }
    // SAFETY: the literal contains no NUL bytes.
    unsafe { assume_nonzero_slice(b"/data/local/tmp/") }
}