use std::os::unix::net::SocketAddr;
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicPtr, AtomicU32};
#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;
#[cfg(target_os = "linux")]
use pidfd_util::{PidFd, PidFdExt};
use crate::thin::{AtomicThinStr, ThinStr};
use crate::{ListenFdNames, ListenFds, atomic};
/// Number of passed file descriptors, as parsed from the `LISTEN_FDS` environment variable by
/// `process_env`. Stays `0` unless `process_env` accepts the environment and stores a value.
static LISTEN_FDS: AtomicU32 = AtomicU32::new(0);
/// Raw value of the `LISTEN_FDNAMES` environment variable, stored by `process_env` together
/// with `LISTEN_FDS` when that variable is present.
static LISTEN_FDNAMES: AtomicThinStr = AtomicThinStr::new();
/// Pointer to the heap-allocated address parsed from `NOTIFY_SOCKET` (produced with
/// `Box::into_raw` in `process_env`); null while no usable address has been stored.
static NOTIFY_SOCKET: AtomicPtr<SocketAddr> = AtomicPtr::new(ptr::null_mut());
/// Builds a [`ListenFds`] from the descriptor count currently recorded in `LISTEN_FDS`.
///
/// The count is `0` unless `process_env` stored a different value.
#[inline]
pub fn listen_fds() -> ListenFds {
    let count = atomic::load!(LISTEN_FDS);
    // SAFETY: NOTE(review): the contract of `ListenFds::new` is not visible here; presumably
    // the count must describe descriptors actually passed to this process, which is what
    // `process_env` records -- confirm against `ListenFds::new`.
    unsafe { ListenFds::new(count) }
}
/// Returns the names stored in `LISTEN_FDNAMES` wrapped in a [`ListenFdNames`], or `None`
/// when nothing has been stored.
#[inline]
pub fn listen_fd_names() -> Option<ListenFdNames> {
    let raw_names = LISTEN_FDNAMES.load()?;
    Some(ListenFdNames::new(raw_names))
}
#[inline]
pub fn notify_socket() -> Option<Box<SocketAddr>> {
NonNull::new(atomic::load!(NOTIFY_SOCKET))
.map(|raw| unsafe { Box::from_raw(raw.as_ptr()) })
}
/// Reads the socket-activation and notification variables from the process environment,
/// records them in this module's statics and then removes them from the environment.
///
/// * `LISTEN_FDS` (and, if present, `LISTEN_FDNAMES`) are only recorded when the environment
///   is addressed to this process: on Linux a parseable `LISTEN_PIDFDID` must equal this
///   process's pidfd id; otherwise (or when that variable is missing or unparseable)
///   `LISTEN_PID` must equal the current process id.
/// * `NOTIFY_SOCKET` is recorded when it is a path starting with `/` or, on Linux, an
///   abstract name introduced by `@`. Anything else, including an empty value, is ignored.
///
/// Malformed values (non-UTF-8, non-numeric, empty) are treated as absent rather than
/// panicking, because this function is registered as a constructor.
///
/// # Safety
///
/// Uses `getenv`/`unsetenv`, which are not safe to run concurrently with other code that
/// modifies the environment; the caller must make sure no such access is in progress.
#[ctor::ctor]
pub(crate) unsafe fn process_env() {
    use std::ffi::{CStr, OsStr};
    use std::os::unix::ffi::OsStrExt;
    use std::str::FromStr;

    /// Looks up `name` in the environment and, when it is set, hands the value to `func`.
    fn parse_env<S, F, R>(name: &S, func: F) -> Option<R>
    where
        S: ?Sized + AsRef<CStr>,
        F: FnOnce(&CStr) -> Option<R>,
    {
        // SAFETY: `name` is a valid NUL-terminated string for the duration of the call.
        let value = NonNull::new(unsafe { libc::getenv(name.as_ref().as_ptr()) })?;
        // SAFETY: a non-null `getenv` result points to a NUL-terminated string; it is only
        // borrowed for the duration of `func`, before the variable is unset below.
        func(unsafe { CStr::from_ptr(value.as_ptr()) })
    }

    /// Parses an environment value via `FromStr`, rejecting non-UTF-8 input instead of
    /// assuming it is valid.
    fn parse_value<F: FromStr>(cstr: &CStr) -> Option<F> {
        cstr.to_str().ok()?.parse().ok()
    }

    // `LISTEN_PID` check, shared by the Linux fallback and the generic path.
    let listen_pid_matches =
        || parse_env(c"LISTEN_PID", parse_value::<u32>) == Some(std::process::id());

    let addressed_to_us: bool = cfg_select! {
        // Prefer the pidfd id when it is present and parseable; a failure to obtain our own
        // id counts as a mismatch and does not fall back to `LISTEN_PID`.
        target_os = "linux" => parse_env(c"LISTEN_PIDFDID", parse_value)
            .map(|id| PidFd::from_self().ok().and_then(|fd| fd.get_id().ok()) == Some(id))
            .unwrap_or_else(listen_pid_matches),
        _ => listen_pid_matches(),
    };

    if addressed_to_us && let Some(fds) = parse_env(c"LISTEN_FDS", parse_value) {
        atomic::store!(LISTEN_FDS, fds);
        if let Some(names) = parse_env(c"LISTEN_FDNAMES", ThinStr::from_cstr) {
            LISTEN_FDNAMES.store(names);
        }
    }

    // Slice patterns make an empty value fall through to `None` instead of panicking on an
    // out-of-bounds index.
    let notify_addr = parse_env(c"NOTIFY_SOCKET", |cstr| match cstr.to_bytes() {
        path @ [b'/', ..] => SocketAddr::from_pathname(OsStr::from_bytes(path)).ok(),
        #[cfg(target_os = "linux")]
        [b'@', name @ ..] => SocketAddr::from_abstract_name(name).ok(),
        _ => None,
    });
    if let Some(addr) = notify_addr {
        // The allocation is intentionally handed over to the static.
        atomic::store!(NOTIFY_SOCKET, Box::into_raw(Box::new(addr)));
    }

    // Remove the variables once consumed (presumably so that child processes do not inherit
    // them). Best effort: the result of `unsetenv` is deliberately ignored.
    for var in [
        c"LISTEN_PID",
        c"LISTEN_PIDFDID",
        c"LISTEN_FDS",
        c"LISTEN_FDNAMES",
        c"NOTIFY_SOCKET",
    ] {
        // SAFETY: `var` is a valid NUL-terminated string without `=`.
        unsafe {
            libc::unsetenv(var.as_ptr());
        }
    }
}
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::mem::forget;
    use std::os::fd::{AsRawFd, IntoRawFd};
    use std::path::Path;

    #[cfg(target_os = "linux")]
    use pidfd_util::{PidFd, PidFdExt};

    use crate::tests::with_env;

    /// `LISTEN_PIDFDID` equal to our own pidfd id: the descriptor count is picked up.
    #[cfg(target_os = "linux")]
    #[test]
    fn listen_pidfd_eq() {
        // Keep the pidfd alive for the whole test, as the inline temporary used to be.
        let own_pidfd = PidFd::from_self().unwrap();
        let own_id = own_pidfd.get_id().unwrap().to_string();
        with_env([("LISTEN_PIDFDID", own_id.as_str()), ("LISTEN_FDS", "1")], || {
            let inherited = super::listen_fds();
            assert!(inherited.len() > 0);
            // NOTE(review): presumably avoids running the drop logic on descriptors that
            // were never really passed to the test process -- confirm against `ListenFds`.
            forget(inherited);
        });
    }

    /// `LISTEN_PIDFDID` that does not match our pidfd id: no descriptors are reported.
    #[cfg(target_os = "linux")]
    #[test]
    fn listen_pidfd_ne() {
        let env = [("LISTEN_PIDFDID", "0"), ("LISTEN_FDS", "1")];
        with_env(env, || {
            let inherited = super::listen_fds();
            assert_eq!(inherited.len(), 0);
        });
    }

    /// `LISTEN_PID` equal to our process id: the descriptor count is picked up.
    #[test]
    fn listen_pid_eq() {
        let own_pid = std::process::id().to_string();
        with_env([("LISTEN_PID", own_pid.as_str()), ("LISTEN_FDS", "1")], || {
            let inherited = super::listen_fds();
            assert!(inherited.len() > 0);
            forget(inherited);
        });
    }

    /// `LISTEN_PID` that does not match our process id: no descriptors are reported.
    #[test]
    fn listen_pid_ne() {
        let env = [("LISTEN_PID", "0"), ("LISTEN_FDS", "1")];
        with_env(env, || {
            let inherited = super::listen_fds();
            assert_eq!(inherited.len(), 0);
        });
    }

    /// The first descriptor yielded by `listen_fds()` is `ListenFds::START`.
    #[test]
    fn listen_fds() {
        let own_pid = std::process::id().to_string();
        with_env([("LISTEN_PID", own_pid.as_str()), ("LISTEN_FDS", "1")], || {
            // Deliberately leak a descriptor; the test expects it to land on `START`.
            let leaked = File::open("/dev/null").unwrap().into_raw_fd();
            assert_eq!(leaked, super::ListenFds::START);
            assert_eq!(
                super::listen_fds().next().unwrap().unwrap().as_raw_fd(),
                super::ListenFds::START
            );
        });
    }

    /// `LISTEN_FDNAMES` is split on `:` and empty entries are preserved.
    #[test]
    fn listen_fd_names() {
        let own_pid = std::process::id().to_string();
        let env = [
            ("LISTEN_PID", own_pid.as_str()),
            ("LISTEN_FDS", "0"),
            ("LISTEN_FDNAMES", "foo:bar::spam"),
        ];
        with_env(env, || {
            let names: Vec<String> = super::listen_fd_names()
                .unwrap()
                .into_iter()
                .map(Into::into)
                .collect();
            assert_eq!(names, vec!["foo", "bar", "", "spam"]);
        });
    }

    /// A `NOTIFY_SOCKET` starting with `/` is exposed as a pathname address.
    #[test]
    fn notify_socket_path() {
        let expected = Path::new("/foo/bar/socket");
        with_env([("NOTIFY_SOCKET", expected)], || {
            let addr = super::notify_socket().unwrap();
            assert_eq!(addr.as_pathname().unwrap(), expected);
        });
    }

    /// A `NOTIFY_SOCKET` starting with `@` is exposed as an abstract-namespace address.
    #[cfg(target_os = "linux")]
    #[test]
    fn notify_socket_abstract() {
        use std::os::linux::net::SocketAddrExt;

        let name = "foobar";
        let value = format!("@{name}");
        with_env([("NOTIFY_SOCKET", &value)], || {
            let addr = super::notify_socket().unwrap();
            assert_eq!(addr.as_abstract_name().unwrap(), name.as_bytes());
        });
    }
}