use std::env;
use std::ffi::{OsStr, OsString};
use std::io;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::net::UnixDatagram;
pub fn notify(state: &str) -> io::Result<bool> {
let path: OsString = match env::var_os("NOTIFY_SOCKET") {
Some(p) => p,
None => return Ok(false),
};
let bytes = path.as_bytes();
if bytes.first() == Some(&b'@') {
send_abstract(&bytes[1..], state.as_bytes())?;
} else {
let sock = UnixDatagram::unbound()?;
sock.connect::<&OsStr>(path.as_os_str())?;
sock.send(state.as_bytes())?;
}
Ok(true)
}
/// Sends `payload` as one datagram to the abstract Unix socket called `name`.
///
/// `name` is the socket name *without* the leading NUL byte; this function
/// writes that NUL itself as the first byte of `sun_path` and copies `name`
/// right after it. The address length passed to the kernel covers the family
/// field, the NUL and exactly `name.len()` bytes, so no trailing padding becomes
/// part of the name.
///
/// # Errors
///
/// * [`io::ErrorKind::InvalidInput`] when `name` does not fit in `sun_path`
///   after the leading NUL byte.
/// * The OS error reported by `socket(2)` or `sendto(2)` when either call fails.
fn send_abstract(name: &[u8], payload: &[u8]) -> io::Result<()> {
    use std::mem::{MaybeUninit, size_of};
    use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};

    // SAFETY: `sockaddr_un` is a plain C struct made only of integer fields, so
    // the all-zero bit pattern is a valid value for it.
    let mut addr: libc::sockaddr_un =
        unsafe { MaybeUninit::<libc::sockaddr_un>::zeroed().assume_init() };

    // One byte of `sun_path` is reserved for the leading NUL; derive the limit
    // from the real array length instead of hard-coding it.
    let capacity = addr.sun_path.len() - 1;
    if name.len() > capacity {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "abstract socket name too long",
        ));
    }

    addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
    // `sun_path[0]` stays 0 from the zero-initialisation above; the name follows it.
    for (dst, &src) in addr.sun_path[1..].iter_mut().zip(name) {
        // `c_char` is `i8` or `u8` depending on the target; this is a plain
        // bit-for-bit byte copy either way.
        *dst = src as libc::c_char;
    }
    // Bounded by the size of `sockaddr_un`, so the cast cannot truncate.
    let addrlen = (size_of::<libc::sa_family_t>() + 1 + name.len()) as libc::socklen_t;

    // SAFETY: plain FFI call with constant arguments; the result is checked below.
    let raw = unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, 0) };
    if raw < 0 {
        return Err(io::Error::last_os_error());
    }
    // SAFETY: `raw` is a freshly created, valid descriptor that nothing else
    // owns. Wrapping it guarantees it is closed on every exit path.
    let socket = unsafe { OwnedFd::from_raw_fd(raw) };

    // SAFETY: `payload` is valid for `payload.len()` bytes, `addr` is a fully
    // initialised `sockaddr_un`, and `addrlen` does not exceed its size.
    let sent = unsafe {
        libc::sendto(
            socket.as_raw_fd(),
            payload.as_ptr().cast::<libc::c_void>(),
            payload.len(),
            0,
            (&addr as *const libc::sockaddr_un).cast::<libc::sockaddr>(),
            addrlen,
        )
    };
    if sent < 0 {
        // The error is captured before `socket` is dropped, so the implicit
        // `close` cannot clobber `errno` first.
        return Err(io::Error::last_os_error());
    }
    Ok(())
}
/// Ready-made `READY=1` message for [`notify`].
///
/// NOTE(review): the name and value suggest the sd_notify-style "start-up
/// finished" message; confirm the exact semantics against the service
/// manager's documentation.
pub const STATE_READY: &str = "READY=1";
/// Ready-made `STOPPING=1` message for [`notify`] (presumably announces that
/// shutdown has begun; confirm against the service manager's documentation).
pub const STATE_STOPPING: &str = "STOPPING=1";
/// Ready-made `RELOADING=1` message for [`notify`] (presumably announces that a
/// configuration reload has begun; confirm against the service manager's
/// documentation).
pub const STATE_RELOADING: &str = "RELOADING=1";
/// Ready-made `WATCHDOG=1` message for [`notify`] (presumably a watchdog
/// keep-alive; confirm against the service manager's documentation).
// NOTE(review): `dead_code` is allowed here presumably because nothing in the
// crate sends this message yet; verify before removing the attribute.
#[allow(dead_code)]
pub const STATE_WATCHDOG: &str = "WATCHDOG=1";
/// Sends a `MAINPID=<pid>` message through [`notify`].
///
/// The return value and any error are exactly those of [`notify`]: `Ok(false)`
/// when `NOTIFY_SOCKET` is unset, `Ok(true)` once the datagram was sent.
pub fn main_pid(pid: u32) -> io::Result<bool> {
    let mut message = String::from("MAINPID=");
    message.push_str(&pid.to_string());
    notify(&message)
}
/// Sends a `STATUS=<text>` message through [`notify`].
///
/// `text` is appended verbatim after `STATUS=`; no escaping or validation is
/// performed. The return value and any error are exactly those of [`notify`].
#[allow(dead_code)]
pub fn status(text: &str) -> io::Result<bool> {
    let message = ["STATUS=", text].concat();
    notify(&message)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::os::unix::net::UnixDatagram;
    use tempfile::TempDir;

    /// Puts `NOTIFY_SOCKET` back to the remembered value when dropped, so the
    /// environment is restored even if the test body panics.
    struct NotifySocketRestore {
        previous: Option<OsString>,
    }

    impl Drop for NotifySocketRestore {
        fn drop(&mut self) {
            // SAFETY: every test that touches `NOTIFY_SOCKET` is serialised via
            // `#[serial(notify_socket)]`; this assumes no other thread accesses
            // the process environment at the same time.
            unsafe {
                match self.previous.take() {
                    Some(v) => env::set_var("NOTIFY_SOCKET", v),
                    None => env::remove_var("NOTIFY_SOCKET"),
                }
            }
        }
    }

    /// Runs `body` with `NOTIFY_SOCKET` set to `value` (or removed for `None`)
    /// and restores the previous value afterwards, including on panic.
    fn with_notify_socket<R>(value: Option<&OsStr>, body: impl FnOnce() -> R) -> R {
        let _restore = NotifySocketRestore {
            previous: env::var_os("NOTIFY_SOCKET"),
        };
        // SAFETY: same reasoning as in `NotifySocketRestore::drop`.
        unsafe {
            match value {
                Some(v) => env::set_var("NOTIFY_SOCKET", v),
                None => env::remove_var("NOTIFY_SOCKET"),
            }
        }
        body()
    }

    #[test]
    #[serial(notify_socket)]
    fn notify_no_socket_set_is_noop() {
        let res = with_notify_socket(None, || notify(STATE_READY).expect("noop"));
        assert!(
            !res,
            "notify must report `false` when NOTIFY_SOCKET is unset"
        );
    }

    #[test]
    #[serial(notify_socket)]
    fn notify_writes_payload_to_socket() {
        let dir = TempDir::new().expect("tempdir");
        let socket_path = dir.path().join("notify.sock");
        let listener = UnixDatagram::bind(&socket_path).expect("bind notify socket");
        let sent = with_notify_socket(Some(socket_path.as_os_str()), || {
            notify(STATE_READY).expect("notify must succeed")
        });
        assert!(sent, "notify must report `true` when datagram was sent");
        let mut buf = [0u8; 64];
        let (n, _addr) = listener.recv_from(&mut buf).expect("recv notify datagram");
        assert_eq!(&buf[..n], STATE_READY.as_bytes());
    }

    #[test]
    #[serial(notify_socket)]
    fn notify_writes_payload_to_abstract_socket() {
        // The process id keeps the abstract name unique across concurrent test runs.
        let abstract_name = format!("sozu-test-notify-{}", std::process::id());
        let listener = bind_abstract(abstract_name.as_bytes()).expect("bind abstract");
        let env_value = format!("@{abstract_name}");
        let sent = with_notify_socket(Some(OsStr::new(&env_value)), || {
            notify(STATE_READY).expect("notify must succeed on abstract socket")
        });
        assert!(sent, "notify must report `true` for abstract socket");
        let mut buf = [0u8; 64];
        let (n, _) = listener
            .recv_from(&mut buf)
            .expect("recv notify datagram on abstract socket");
        assert_eq!(&buf[..n], STATE_READY.as_bytes());
    }

    /// Creates a datagram socket bound to the abstract Unix address `name`
    /// (given without the leading NUL byte).
    fn bind_abstract(name: &[u8]) -> io::Result<UnixDatagram> {
        use std::mem::{MaybeUninit, size_of};
        use std::os::fd::{AsRawFd, FromRawFd};

        // SAFETY: `sockaddr_un` consists only of integer fields, so all-zero
        // bytes form a valid value.
        let mut addr: libc::sockaddr_un =
            unsafe { MaybeUninit::<libc::sockaddr_un>::zeroed().assume_init() };
        // Refuse names that would overflow `sun_path` after the leading NUL.
        if name.len() > addr.sun_path.len() - 1 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "abstract socket name too long",
            ));
        }
        addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
        // `sun_path[0]` stays 0; the name is copied right after it.
        for (dst, &src) in addr.sun_path[1..].iter_mut().zip(name) {
            *dst = src as libc::c_char;
        }
        let addrlen = (size_of::<libc::sa_family_t>() + 1 + name.len()) as libc::socklen_t;

        // SAFETY: plain FFI call; the result is checked below.
        let raw =
            unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, 0) };
        if raw < 0 {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: `raw` is a fresh descriptor owned by nobody else; the
        // `UnixDatagram` closes it on drop, including on the error path below.
        let socket = unsafe { UnixDatagram::from_raw_fd(raw) };

        // SAFETY: `addr` is fully initialised and `addrlen` does not exceed its size.
        let rc = unsafe {
            libc::bind(
                socket.as_raw_fd(),
                (&addr as *const libc::sockaddr_un).cast::<libc::sockaddr>(),
                addrlen,
            )
        };
        if rc < 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(socket)
    }
}