microsd 0.2.0

Light‐weight systemd auxiliaries
Documentation
use std::os::unix::net::SocketAddr;
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicPtr, AtomicU32};

#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;

#[cfg(target_os = "linux")]
use pidfd_util::{PidFd, PidFdExt};

use crate::thin::{AtomicThinStr, ThinStr};
use crate::{ListenFdNames, ListenFds, atomic};

/// Number of listen file descriptors recorded by `process_env` from the
/// `LISTEN_FDS` environment variable; stays zero when nothing was recorded.
/// Consumed by `listen_fds`.
static LISTEN_FDS: AtomicU32 = AtomicU32::new(0);
/// `LISTEN_FDNAMES` value recorded by `process_env` (converted with
/// `ThinStr::from_cstr`), if any. Consumed by `listen_fd_names`.
static LISTEN_FDNAMES: AtomicThinStr = AtomicThinStr::new();
/// Notification socket address leaked with `Box::into_raw` by `process_env`,
/// or null when none was recorded. `notify_socket` takes ownership back with
/// `Box::from_raw`.
static NOTIFY_SOCKET: AtomicPtr<SocketAddr> = AtomicPtr::new(ptr::null_mut::<SocketAddr>());

/// Retrieve listen file descriptors.
///
/// Yields the file descriptors that the system manager handed over through
/// the `LISTEN_FDS` environment variable. Only the first call observes them;
/// every later call gets an empty iterator.
///
/// # Example
///
/// ```
/// use std::os::unix::fs::FileTypeExt;
///
/// for fd in microsd::listen_fds().filter_map(|res| res.ok()) {
///     if fd.is_socket() {
///         // …
///     }
/// }
/// # Ok::<(), std::io::Error>(())
/// ```
#[inline]
pub fn listen_fds() -> ListenFds {
	let count = atomic::load!(LISTEN_FDS);
	// SAFETY: `LISTEN_FDS` is only written by `process_env`, after the
	// `LISTEN_PIDFDID`/`LISTEN_PID` check confirmed the descriptors were
	// addressed to this process. NOTE(review): this relies on `atomic::load!`
	// handing the count out only once (as documented above) so descriptors
	// are never claimed twice — confirm against `ListenFds::new`'s contract.
	unsafe { ListenFds::new(count) }
}

/// Retrieve listen file descriptor names.
///
/// Returns the names passed by the system manager through the
/// `LISTEN_FDNAMES` environment variable, or `None` when no names were
/// recorded. The names are listed in the same order as the descriptors
/// yielded by [`listen_fds`], so the two can be zipped together.
///
/// # Example
///
/// ```ignore
/// for (fd, name) in microsd::listen_fds().zip(microsd::listen_fd_names().unwrap()) {
///     if name == "stored" {
///         // …
///     }
/// }
/// # Ok::<(), std::io::Error>(())
/// ```
#[inline]
pub fn listen_fd_names() -> Option<ListenFdNames> {
	let names = LISTEN_FDNAMES.load()?;
	Some(ListenFdNames::new(names))
}

/// Retrieve notification socket address.
///
/// On the first call, returns the system manager's notification socket
/// address (taken from the `NOTIFY_SOCKET` environment variable) as a boxed
/// [`SocketAddr`], or `None` when it was not available. Every later call
/// returns `None`.
///
/// Only UNIX domain sockets are supported, both path‐based and abstract;
/// `VSOCK` sockets are not.
///
/// # Example
///
/// ```ignore
/// use std::os::unix::net::UnixDatagram;
///
/// let notify = UnixDatagram::unbound()?;
/// notify.connect_addr(&microsd::notify_socket().unwrap())?;
///
/// notify.send(b"READY=1")?;
/// # Ok::<(), std::io::Error>(())
/// ```
#[inline]
pub fn notify_socket() -> Option<Box<SocketAddr>> {
	let raw = NonNull::new(atomic::load!(NOTIFY_SOCKET))?;
	// SAFETY: the only non-null value ever stored in `NOTIFY_SOCKET` comes
	// from `Box::into_raw` in `process_env`. NOTE(review): soundness relies on
	// `atomic::load!` handing the pointer out only once (subsequent calls
	// return `None`, as documented above), so the box is reclaimed exactly
	// once — confirm against the `atomic` module.
	Some(unsafe { Box::from_raw(raw.as_ptr()) })
}

/// Process environment variables.
///
/// This function processes the `LISTEN_PIDFDID`, `LISTEN_PID`, `LISTEN_FDS`,
/// `LISTEN_FDNAMES` and `NOTIFY_SOCKET` variables from the environment,
/// removing all of them afterwards.
///
/// `LISTEN_FDS` (and with it `LISTEN_FDNAMES`) is only recorded when the
/// variables are addressed to this process. On Linux, a parseable
/// `LISTEN_PIDFDID` is compared against the ID of this process's own pidfd
/// (the check fails if that ID cannot be determined); otherwise `LISTEN_PID`
/// is compared against the process ID. On other platforms only `LISTEN_PID`
/// is consulted.
///
/// Numeric variables whose values are not valid UTF-8 or do not parse are
/// treated as absent. `NOTIFY_SOCKET` values that are empty, or are neither
/// an absolute path nor (on Linux) an `@`-prefixed abstract socket name, are
/// ignored.
///
/// # Safety
///
/// This function is called automatically during programme initialisation and
/// generally not safe to call directly: it reads and modifies the process
/// environment through `getenv`/`unsetenv`, which races with any other thread
/// accessing the environment.
#[ctor::ctor]
pub(crate) unsafe fn process_env() {
	use std::ffi::{CStr, OsStr};
	use std::os::unix::ffi::OsStrExt;
	use std::str::FromStr;

	// Look up `name` in the environment and pass its value, if set, to `func`.
	fn parse_env<S, F, R>(name: &S, func: F) -> Option<R>
	where
		S: ?Sized + AsRef<CStr>,
		F: FnOnce(&CStr) -> Option<R>, {
		// SAFETY: `name` is NUL-terminated; `getenv` returns either NULL or a
		// pointer to a NUL-terminated string, and NULL is filtered out by
		// `NonNull::new` before `CStr::from_ptr` is reached.
		NonNull::new(unsafe { libc::getenv(name.as_ref().as_ptr()) })
			.and_then(|ptr| func(unsafe { CStr::from_ptr(ptr.as_ptr()) }))
	}

	// Parse an environment value as `F`. The environment is not guaranteed to
	// be UTF-8, so invalid values are rejected rather than being reinterpreted
	// unchecked (which would be undefined behaviour).
	fn parse_str<F: FromStr>(cstr: &CStr) -> Option<F> {
		cstr.to_str().ok()?.parse().ok()
	}

	let addressed_to_us = cfg_select! {
		// Try LISTEN_PIDFDID first, falling back to LISTEN_PID
		target_os = "linux" =>
			parse_env(c"LISTEN_PIDFDID", parse_str)
				.map(|ino| PidFd::from_self().ok()
					.and_then(|fd| fd.get_id().ok()) == Some(ino))
				.unwrap_or_else(|| parse_env(c"LISTEN_PID", parse_str::<u32>)
					== Some(std::process::id())),
		_ => parse_env(c"LISTEN_PID", parse_str::<u32>) == Some(std::process::id())
	};

	if addressed_to_us && let Some(fds) = parse_env(c"LISTEN_FDS", parse_str::<u32>) {
		atomic::store!(LISTEN_FDS, fds);

		if let Some(names) = parse_env(c"LISTEN_FDNAMES", ThinStr::from_cstr) {
			LISTEN_FDNAMES.store(names);
		}
	}

	// `NOTIFY_SOCKET` is either an absolute path or, on Linux, an abstract
	// socket name written with a leading `@`. An empty value yields `None`
	// instead of panicking on an out-of-bounds index.
	if let Some(addr) = parse_env(c"NOTIFY_SOCKET", |cstr| {
		let bytes = cstr.to_bytes();
		match bytes.split_first()? {
			(b'/', _) => SocketAddr::from_pathname(OsStr::from_bytes(bytes)).ok(),
			#[cfg(target_os = "linux")]
			(b'@', name) => SocketAddr::from_abstract_name(name).ok(),
			_ => None,
		}
	}) {
		// Ownership of the box passes to `NOTIFY_SOCKET`; `notify_socket`
		// reclaims it with `Box::from_raw`.
		atomic::store!(NOTIFY_SOCKET, Box::into_raw(Box::new(addr)));
	}

	for var in [
		c"LISTEN_PID",
		c"LISTEN_PIDFDID",
		c"LISTEN_FDS",
		c"LISTEN_FDNAMES",
		c"NOTIFY_SOCKET",
	] {
		// SAFETY: `var` is NUL-terminated; exclusive access to the environment
		// is part of this function's `# Safety` contract.
		unsafe {
			libc::unsetenv(var.as_ptr());
		}
	}
}

#[cfg(test)]
mod tests {
	use std::fs::File;
	use std::mem::forget;
	use std::os::fd::{AsRawFd, IntoRawFd};
	use std::path::Path;

	#[cfg(target_os = "linux")]
	use pidfd_util::{PidFd, PidFdExt};

	use crate::tests::with_env;

	// Descriptors are accepted when `LISTEN_PIDFDID` matches this process.
	#[cfg(target_os = "linux")]
	#[test]
	fn listen_pidfd_eq() {
		let pidfd_id = PidFd::from_self().unwrap().get_id().unwrap().to_string();
		with_env([("LISTEN_PIDFDID", pidfd_id.as_str()), ("LISTEN_FDS", "1")], || {
			let fds = super::listen_fds();
			assert!(fds.len() > 0);
			// NOTE(review): presumably forgotten so that dropping the iterator
			// does not close a descriptor this test does not own — confirm.
			forget(fds);
		});
	}

	// Descriptors are rejected when `LISTEN_PIDFDID` does not match.
	#[cfg(target_os = "linux")]
	#[test]
	fn listen_pidfd_ne() {
		with_env([("LISTEN_PIDFDID", "0"), ("LISTEN_FDS", "1")], || {
			let fds = super::listen_fds();
			assert_eq!(fds.len(), 0);
		});
	}

	// Descriptors are accepted when `LISTEN_PID` matches this process.
	#[test]
	fn listen_pid_eq() {
		let pid = std::process::id().to_string();
		with_env([("LISTEN_PID", pid.as_str()), ("LISTEN_FDS", "1")], || {
			let fds = super::listen_fds();
			assert!(fds.len() > 0);
			// NOTE(review): see `listen_pidfd_eq` — avoids closing on drop.
			forget(fds);
		});
	}

	// Descriptors are rejected when `LISTEN_PID` does not match.
	#[test]
	fn listen_pid_ne() {
		with_env([("LISTEN_PID", "0"), ("LISTEN_FDS", "1")], || {
			let fds = super::listen_fds();
			assert_eq!(fds.len(), 0);
		});
	}

	// The first yielded descriptor is `ListenFds::START`.
	#[test]
	fn listen_fds() {
		let pid = std::process::id().to_string();
		with_env([("LISTEN_PID", pid.as_str()), ("LISTEN_FDS", "1")], || {
			// Occupy descriptor `START` with /dev/null so the listen
			// descriptor refers to an open file; the assertion checks that
			// it was indeed the lowest free descriptor.
			assert_eq!(File::open("/dev/null").unwrap().into_raw_fd(), super::ListenFds::START);
			assert_eq!(
				super::listen_fds().next().unwrap().unwrap().as_raw_fd(),
				super::ListenFds::START
			);
		});
	}

	// Colon-separated names are split, preserving empty entries.
	#[test]
	fn listen_fd_names() {
		let pid = std::process::id().to_string();
		with_env(
			[
				("LISTEN_PID", pid.as_str()),
				("LISTEN_FDS", "0"),
				("LISTEN_FDNAMES", "foo:bar::spam"),
			],
			|| {
				assert_eq!(
					super::listen_fd_names()
						.unwrap()
						.into_iter()
						.map(Into::into)
						.collect::<Vec<String>>(),
					vec!["foo", "bar", "", "spam"]
				);
			},
		);
	}

	// An absolute `NOTIFY_SOCKET` path becomes a path-based address.
	#[test]
	fn notify_socket_path() {
		let path = Path::new("/foo/bar/socket");
		with_env([("NOTIFY_SOCKET", path)], || {
			let addr = super::notify_socket().unwrap();
			assert_eq!(addr.as_pathname().unwrap(), path);
		});
	}

	// An `@`-prefixed `NOTIFY_SOCKET` becomes an abstract address.
	#[cfg(target_os = "linux")]
	#[test]
	fn notify_socket_abstract() {
		use std::os::linux::net::SocketAddrExt;

		let name = "foobar";
		with_env([("NOTIFY_SOCKET", &format!("@{name}"))], || {
			let addr = super::notify_socket().unwrap();
			assert_eq!(addr.as_abstract_name().unwrap(), name.as_bytes());
		});
	}
}