use std::boxed::Box;
use std::os::unix::net::SocketAddr;
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicPtr, AtomicU32, Ordering};
#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;
#[cfg(target_os = "linux")]
use pidfd_util::{PidFd, PidFdExt};
use crate::ListenFds;
/// Socket-activation descriptor count recorded by `process_env`.
/// Zero means none were passed to this process, or `listen_fds` already claimed them.
static LISTEN_FDS: AtomicU32 = AtomicU32::new(0_u32);
/// Boxed notification-socket address recorded by `process_env`.
/// The box is owned by this slot until `notify_socket` swaps it out.
/// Null means the address was absent or has already been taken.
static NOTIFY_SOCKET: AtomicPtr<SocketAddr> = AtomicPtr::new(ptr::null_mut::<SocketAddr>());
/// Claims the descriptors handed to this process via socket activation.
///
/// The recorded count is atomically reset to zero. Only the first call sees the
/// inherited count; every later call builds a `ListenFds` from a count of zero.
#[inline]
pub fn listen_fds() -> ListenFds {
    let count = LISTEN_FDS.swap(0, Ordering::AcqRel);
    // SAFETY: swapping the counter to zero hands each recorded count out at most
    // once, so no two `ListenFds` values claim the same descriptors. This is
    // presumably the contract of `ListenFds::new`; confirm against its `# Safety` docs.
    unsafe { ListenFds::new(count) }
}
/// Takes the readiness-notification address captured from `NOTIFY_SOCKET`.
///
/// Returns `Some` on the first call when a supported address was recorded at
/// startup. Later calls return `None`, because the stored pointer is swapped for null.
#[inline]
pub fn notify_socket() -> Option<Box<SocketAddr>> {
    let taken = NOTIFY_SOCKET.swap(ptr::null_mut(), Ordering::AcqRel);
    let owned = NonNull::new(taken)?;
    // SAFETY: every non-null pointer stored in NOTIFY_SOCKET is produced by
    // `Box::into_raw` in `process_env`. The swap above gives this call exclusive
    // ownership of it, so it is reboxed exactly once.
    Some(unsafe { Box::from_raw(owned.as_ptr()) })
}
#[ctor::ctor]
pub unsafe fn process_env() {
use std::env;
use std::ffi::OsStr;
use std::str::FromStr;
fn parse<S: AsRef<OsStr>, F: FromStr>(string: S) -> Result<F, <F as FromStr>::Err> {
str::parse::<F>(unsafe { str::from_utf8_unchecked(string.as_ref().as_encoded_bytes()) })
}
if cfg_select! {
target_os = "linux" => env::var_os("LISTEN_PIDFDID")
.and_then(|var| parse::<_, u64>(var).ok())
.map(|ino| PidFd::from_self().ok()
.and_then(|fd| fd.get_id().ok()) == Some(ino))
.unwrap_or_else(|| env::var_os("LISTEN_PID")
.and_then(|var| parse(var).ok()) == Some(std::process::id())),
_ => env::var_os("LISTEN_PID").and_then(|var| parse(var).ok()) == Some(std::process::id()),
} && let Some(fds) = env::var_os("LISTEN_FDS").and_then(|var| parse(var).ok())
{
if cfg!(debug_assertions) {
assert_eq!(LISTEN_FDS.swap(fds, Ordering::AcqRel), 0);
} else {
LISTEN_FDS.store(fds, Ordering::Release);
}
}
if let Some(sock) = env::var_os("NOTIFY_SOCKET") {
if let Some(addr) = match sock.as_encoded_bytes()[0] {
b'/' => SocketAddr::from_pathname(sock).ok(),
#[cfg(target_os = "linux")]
b'@' => SocketAddr::from_abstract_name(&sock.as_encoded_bytes()[1..]).ok(),
_ => None,
} {
let raw = Box::into_raw(Box::new(addr));
if cfg!(debug_assertions) {
assert_eq!(NOTIFY_SOCKET.swap(raw, Ordering::AcqRel), ptr::null_mut());
} else {
NOTIFY_SOCKET.store(raw, Ordering::Release);
}
}
};
for var in [
"LISTEN_PID",
"LISTEN_PIDFDID",
"LISTEN_FDS",
"LISTEN_FDNAMES",
"NOTIFY_SOCKET",
] {
unsafe {
env::remove_var(var);
}
}
}
// Unit tests for the environment-capture constructor and its accessors.
// NOTE(review): `crate::tests::with_env` is defined elsewhere. These tests
// presumably rely on it setting the given variables, re-running `process_env`,
// and then invoking the closure; confirm against its definition.
#[cfg(test)]
mod tests {
use std::fs::File;
use std::mem::forget;
use std::os::fd::{AsRawFd, IntoRawFd};
use std::path::Path;
#[cfg(target_os = "linux")]
use pidfd_util::{PidFd, PidFdExt};
use crate::tests::with_env;
// A LISTEN_PIDFDID equal to this process's pidfd id is accepted on its own,
// without LISTEN_PID being set.
#[cfg(target_os = "linux")]
#[test]
fn listen_pidfd_eq() {
with_env(
[
(
"LISTEN_PIDFDID",
format!("{}", PidFd::from_self().unwrap().get_id().unwrap()).as_str(),
),
("LISTEN_FDS", "1"),
],
|| {
let fds = super::listen_fds();
assert!(fds.len() > 0);
// NOTE(review): presumably dropping `ListenFds` closes the claimed
// descriptors, which this test does not own; `forget` avoids that. Confirm.
forget(fds);
},
);
}
// A LISTEN_PIDFDID that does not match this process yields no descriptors.
#[cfg(target_os = "linux")]
#[test]
fn listen_pidfd_ne() {
with_env([("LISTEN_PIDFDID", "0"), ("LISTEN_FDS", "1")], || {
let fds = super::listen_fds();
assert_eq!(fds.len(), 0);
});
}
// A LISTEN_PID equal to this process id is accepted.
#[test]
fn listen_pid_eq() {
with_env(
[
("LISTEN_PID", format!("{}", std::process::id()).as_str()),
("LISTEN_FDS", "1"),
],
|| {
let fds = super::listen_fds();
assert!(fds.len() > 0);
// See `listen_pidfd_eq`: avoid dropping descriptors this test does not own.
forget(fds);
},
);
}
// A LISTEN_PID naming another process yields no descriptors.
#[test]
fn listen_pid_ne() {
with_env([("LISTEN_PID", "0"), ("LISTEN_FDS", "1")], || {
let fds = super::listen_fds();
assert_eq!(fds.len(), 0);
});
}
// The first yielded descriptor is `ListenFds::START`. `/dev/null` is opened and
// leaked via `into_raw_fd` so that a real file occupies START. This relies on
// `open` returning the lowest free descriptor, so it assumes nothing else in the
// test process already holds START.
#[test]
fn listen_fds() {
with_env(
[
("LISTEN_PID", format!("{}", std::process::id()).as_str()),
("LISTEN_FDS", "1"),
],
|| {
assert_eq!(File::open("/dev/null").unwrap().into_raw_fd(), super::ListenFds::START);
assert_eq!(
super::listen_fds().next().unwrap().unwrap().as_raw_fd(),
super::ListenFds::START
);
},
);
}
// A NOTIFY_SOCKET starting with `/` is parsed as a filesystem-path address.
#[test]
fn notify_socket_path() {
let path = Path::new("/foo/bar/socket");
with_env([("NOTIFY_SOCKET", path)], || {
let notify = super::notify_socket().unwrap();
assert_eq!(notify.as_pathname().unwrap(), path);
});
}
// On Linux, a NOTIFY_SOCKET starting with `@` is parsed as an abstract address
// whose name excludes the `@`.
#[cfg(target_os = "linux")]
#[test]
fn notify_socket_abstract() {
use std::os::linux::net::SocketAddrExt;
let name = "foobar";
with_env([("NOTIFY_SOCKET", &format!("@{}", name))], || {
let notify = super::notify_socket().unwrap();
assert_eq!(notify.as_abstract_name().unwrap(), name.as_bytes());
});
}
}