use std::io;
use std::path::{Path, PathBuf};
use tokio::net::{UnixListener, UnixStream};
/// Permission bits applied to the socket file by `set_socket_permissions`:
/// read and write for the owning user, nothing for group or others.
const SOCKET_MODE: u32 = 0o600;
/// Extracts the filesystem path from a `unix://` endpoint specification.
///
/// Returns `Some(path)` holding everything after the literal `unix://`
/// prefix, or `None` when `spec` does not start with that prefix. The
/// remainder is taken verbatim; it is not validated or normalised.
pub fn parse_unix_scheme(spec: &str) -> Option<PathBuf> {
    let socket_path = spec.strip_prefix("unix://")?;
    Some(PathBuf::from(socket_path))
}
/// Cleanup handle for a socket file on disk.
///
/// Dropping the guard attempts to delete the file at `path`.
#[derive(Debug)]
pub struct UnixSocketGuard {
    /// Filesystem location that is unlinked when the guard is dropped.
    path: PathBuf,
}

impl Drop for UnixSocketGuard {
    /// Removes the file at the guarded path.
    ///
    /// Removal is best effort: any error (for example the file already being
    /// gone) is discarded, because `drop` has no way to report it.
    fn drop(&mut self) {
        let socket_path: &Path = self.path.as_path();
        std::fs::remove_file(socket_path).ok();
    }
}
/// Binds a [`UnixListener`] at `path`, replacing a stale socket file if one
/// is in the way, and restricts the socket file to `SOCKET_MODE`.
///
/// When the first bind attempt fails with [`io::ErrorKind::AddrInUse`], the
/// path is probed by connecting to it:
///
/// * if the connection succeeds, another process is serving the socket and
///   an `AddrInUse` error is returned;
/// * if the connection fails (for any reason), the existing file is treated
///   as stale, removed, and the bind is retried once.
///
/// On success the listener is returned together with a [`UnixSocketGuard`]
/// that unlinks the socket file when dropped.
///
/// # Errors
///
/// Returns the underlying I/O error if binding fails for a reason other than
/// `AddrInUse`, if the socket is live, if the stale file cannot be removed,
/// if the retried bind fails, or if the permissions cannot be set. In the
/// last case the socket file created by this call is unlinked again before
/// returning.
pub async fn bind_unix_listener(path: &Path) -> io::Result<(UnixListener, UnixSocketGuard)> {
    let listener = match UnixListener::bind(path) {
        Ok(listener) => listener,
        Err(e) if e.kind() == io::ErrorKind::AddrInUse => {
            // Something already exists at `path`. A successful connect means
            // a live server owns it, so refuse to take it over.
            if UnixStream::connect(path).await.is_ok() {
                return Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    format!(
                        "unix socket {} is already in use by a running process",
                        path.display()
                    ),
                ));
            }
            // Nobody answered: treat the file as a leftover and reclaim the path.
            tracing::warn!(path = %path.display(), "removing stale unix socket");
            std::fs::remove_file(path)?;
            UnixListener::bind(path)?
        }
        Err(e) => return Err(e),
    };

    // Arm the guard before the remaining fallible step. If setting the
    // permissions fails, the early return drops the guard and unlinks the
    // socket file this call just created instead of leaving it on disk.
    let guard = UnixSocketGuard {
        path: path.to_path_buf(),
    };
    set_socket_permissions(path)?;

    Ok((listener, guard))
}
/// Sets the permission bits of the file at `path` to `SOCKET_MODE`.
///
/// # Errors
///
/// Propagates whatever error `std::fs::set_permissions` reports.
fn set_socket_permissions(path: &Path) -> io::Result<()> {
    use std::fs::{self, Permissions};
    use std::os::unix::fs::PermissionsExt;

    let socket_permissions = Permissions::from_mode(SOCKET_MODE);
    fs::set_permissions(path, socket_permissions)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh temporary directory and the socket path used inside it.
    ///
    /// The directory handle is returned alongside the path so the caller can
    /// keep it alive for the duration of the test.
    fn scratch_socket() -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("api.sock");
        (dir, sock)
    }

    /// The `unix://` prefix is stripped; other schemes yield `None`.
    #[test]
    fn parse_unix_scheme_strips_prefix() {
        let parsed = parse_unix_scheme("unix:///run/rsigma.sock");
        assert_eq!(parsed, Some(PathBuf::from("/run/rsigma.sock")));

        for other_scheme in ["stdin", "nats://host/subject"] {
            assert_eq!(parse_unix_scheme(other_scheme), None);
        }
    }

    /// A fresh bind creates the socket with `SOCKET_MODE`, and dropping the
    /// guard removes the file again.
    #[tokio::test]
    async fn bind_creates_owner_only_socket_and_guard_unlinks() {
        use std::os::unix::fs::PermissionsExt;

        let (_dir, sock) = scratch_socket();
        let (listener, guard) = bind_unix_listener(&sock).await.unwrap();
        assert!(sock.exists());

        let metadata = std::fs::metadata(&sock).unwrap();
        let permission_bits = metadata.permissions().mode() & 0o777;
        assert_eq!(permission_bits, SOCKET_MODE);

        drop(listener);
        drop(guard);
        assert!(!sock.exists(), "guard should unlink the socket on drop");
    }

    /// A socket file left behind without a listener is replaced on rebind.
    #[tokio::test]
    async fn bind_recovers_stale_socket() {
        let (_dir, sock) = scratch_socket();

        let (first_listener, first_guard) = bind_unix_listener(&sock).await.unwrap();
        // Leak the guard so the file survives, then close the listener to
        // leave a socket file that nothing is serving.
        std::mem::forget(first_guard);
        drop(first_listener);
        assert!(sock.exists());

        let (_listener, _guard) = bind_unix_listener(&sock).await.unwrap();
    }

    /// Binding a path that a live listener still serves fails with `AddrInUse`.
    #[tokio::test]
    async fn bind_rejects_live_socket() {
        let (_dir, sock) = scratch_socket();
        let (_listener, _guard) = bind_unix_listener(&sock).await.unwrap();

        let second_attempt = bind_unix_listener(&sock).await;
        let err = second_attempt.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::AddrInUse);
    }
}