microsd 0.0.1

Lightweight systemd auxiliaries
Documentation
use std::boxed::Box;
use std::os::unix::net::SocketAddr;
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicPtr, AtomicU32, Ordering};

#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;

use crate::ListenFds;

// Descriptor count parsed from `LISTEN_FDS` at start‐up; reset to zero once taken.
static LISTEN_FDS: AtomicU32 = AtomicU32::new(0_u32);
// Owned address produced by `Box::into_raw`, or null when absent or already taken.
static NOTIFY_SOCKET: AtomicPtr<SocketAddr> = AtomicPtr::new(ptr::null_mut::<SocketAddr>());

/// Take the file descriptors handed over by the system manager.
///
/// The first invocation returns an iterator over the descriptors announced
/// through the `LISTEN_FDS` environment variable. The stored count is reset
/// to zero at the same time, so every later invocation yields an empty
/// iterator.
///
/// # Example
///
/// ```
/// use std::os::unix::fs::FileTypeExt;
///
/// for fd in microsd::listen_fds().filter_map(|res| res.ok()) {
///     if fd.is_socket() {
///         // …
///     }
/// }
/// # Ok::<(), std::io::Error>(())
/// ```
#[inline]
pub fn listen_fds() -> ListenFds {
	// Atomically claim the whole count so it is handed out at most once.
	let count = LISTEN_FDS.swap(0, Ordering::AcqRel);
	// SAFETY: the swap above gives `count` to exactly one caller, so the same
	// descriptors are never wrapped twice — presumably the contract of
	// `ListenFds::new`; TODO confirm against its documentation.
	unsafe { ListenFds::new(count) }
}

/// Retrieve notification socket address.
///
/// The first call to this function will return the system manager
/// notification socket address as a boxed [`SocketAddr`], as provided
/// through the `NOTIFY_SOCKET` environment variable, or `None` if
/// unavailable. Subsequent calls will return `None`.
///
/// The implementation currently supports only UNIX domain sockets:
/// path‐based ones everywhere, and abstract ones on Linux only.
/// `VSOCK` sockets are not supported.
///
/// # Example
///
/// ```no_run
/// use std::os::unix::net::UnixDatagram;
///
/// let notify = UnixDatagram::unbound()?;
/// notify.connect_addr(&microsd::notify_socket().unwrap())?;
///
/// notify.send(b"READY=1")?;
/// # Ok::<(), std::io::Error>(())
/// ```
#[inline]
pub fn notify_socket() -> Option<Box<SocketAddr>> {
	// Swapping in null hands the stored pointer to exactly one caller.
	let raw = NonNull::new(NOTIFY_SOCKET.swap(ptr::null_mut(), Ordering::AcqRel))?;
	// SAFETY: every non‐null pointer stored in `NOTIFY_SOCKET` is produced by
	// `Box::into_raw`, and the swap above transferred sole ownership to us.
	Some(unsafe { Box::from_raw(raw.as_ptr()) })
}

/// Process environment variables.
///
/// This function processes the `LISTEN_PID`, `LISTEN_FDS` and `NOTIFY_SOCKET`
/// variables from the environment, removing them afterwards.
///
/// The `LISTEN_PIDFDID` and `LISTEN_FDNAMES` variables are removed, but their
/// contents are currently not used.
///
/// # Safety
///
/// This function is called automatically during programme initialisation and
/// generally not safe to call directly.
#[ctor::ctor]
pub unsafe fn process_env() {
	use std::env;
	use std::ffi::OsStr;
	use std::str::FromStr;

	fn parse<S: AsRef<OsStr>, F: FromStr>(string: S) -> Result<F, <F as FromStr>::Err> {
		str::parse::<F>(unsafe { str::from_utf8_unchecked(string.as_ref().as_encoded_bytes()) })
	}

	if env::var_os("LISTEN_PID").and_then(|var| parse(var).ok()) == Some(std::process::id())
		&& let Some(fds) = env::var_os("LISTEN_FDS").and_then(|var| parse(var).ok())
	{
		if cfg!(debug_assertions) {
			assert_eq!(LISTEN_FDS.swap(fds, Ordering::AcqRel), 0);
		} else {
			LISTEN_FDS.store(fds, Ordering::Release);
		}
	}

	if let Some(sock) = env::var_os("NOTIFY_SOCKET") {
		// This may involve quite a bit of unnecessary allocation and memcpy’ing
		if let Some(addr) = match sock.as_encoded_bytes()[0] {
			b'/' => SocketAddr::from_pathname(sock).ok(),
			#[cfg(target_os = "linux")]
			b'@' => SocketAddr::from_abstract_name(&sock.as_encoded_bytes()[1..]).ok(),
			_ => None,
		} {
			let raw = Box::into_raw(Box::new(addr));
			if cfg!(debug_assertions) {
				assert_eq!(NOTIFY_SOCKET.swap(raw, Ordering::AcqRel), ptr::null_mut());
			} else {
				NOTIFY_SOCKET.store(raw, Ordering::Release);
			}
		}
	};

	for var in [
		"LISTEN_PID",
		"LISTEN_PIDFDID",
		"LISTEN_FDS",
		"LISTEN_FDNAMES",
		"NOTIFY_SOCKET",
	] {
		unsafe {
			env::remove_var(var);
		}
	}
}

#[cfg(test)]
mod tests {
	use std::fs::File;
	use std::mem::forget;
	use std::os::fd::{AsRawFd, IntoRawFd};
	use std::path::Path;

	use crate::tests::with_env;

	#[test]
	fn listen_pid_eq() {
		with_env(
			[
				("LISTEN_PID", format!("{}", std::process::id()).as_str()),
				("LISTEN_FDS", "1"),
			],
			|| {
				let fds = super::listen_fds();
				assert!(fds.len() > 0);
				forget(fds);
			},
		);
	}

	#[test]
	fn listen_pid_ne() {
		with_env([("LISTEN_PID", "0"), ("LISTEN_FDS", "1")], || {
			let fds = super::listen_fds();
			assert_eq!(fds.len(), 0);
		});
	}

	#[test]
	fn listen_fds() {
		with_env(
			[
				("LISTEN_PID", format!("{}", std::process::id()).as_str()),
				("LISTEN_FDS", "1"),
			],
			|| {
				assert_eq!(File::open("/dev/null").unwrap().into_raw_fd(), super::ListenFds::START);
				assert_eq!(
					super::listen_fds().next().unwrap().unwrap().as_raw_fd(),
					super::ListenFds::START
				);
			},
		);
	}

	#[test]
	fn notify_socket_path() {
		let path = Path::new("/foo/bar/socket");
		with_env([("NOTIFY_SOCKET", path)], || {
			let notify = super::notify_socket().unwrap();
			assert_eq!(notify.as_pathname().unwrap(), path);
		});
	}

	#[cfg(target_os = "linux")]
	#[test]
	fn notify_socket_abstract() {
		use std::os::linux::net::SocketAddrExt;

		let name = "foobar";
		with_env([("NOTIFY_SOCKET", &format!("@{}", name))], || {
			let notify = super::notify_socket().unwrap();
			assert_eq!(notify.as_abstract_name().unwrap(), name.as_bytes());
		});
	}
}