use std::os::unix::net::UnixDatagram;
use std::time::{Duration, Instant};
/// Client for systemd's service notification protocol (`sd_notify(3)`):
/// sends `READY=1`, `WATCHDOG=1` and `STOPPING=1` datagrams to `NOTIFY_SOCKET`.
///
/// When no socket is configured every notification is silently skipped, so
/// callers can use it unconditionally.
#[derive(Debug)]
pub struct SdNotify {
    /// Connected notification socket; `None` when `NOTIFY_SOCKET` was unset
    /// or could not be opened.
    sock: Option<UnixDatagram>,
    /// Watchdog ping period (half of `WATCHDOG_USEC`); `None` when the
    /// watchdog is disabled or has been handed off to a [`WatchdogNotifier`].
    watchdog_interval: Option<Duration>,
    /// When the last `WATCHDOG=1` was sent, or when the notifier was created.
    last_notify: Instant,
}
/// Owns watchdog pinging once it has been split off an [`SdNotify`], so that
/// pings can be driven independently of it (e.g. from another thread).
#[derive(Debug)]
pub struct WatchdogNotifier {
    /// Duplicate of the notification socket, obtained with `try_clone`.
    sock: UnixDatagram,
    /// Ping period (half of `WATCHDOG_USEC`).
    interval: Duration,
    /// When the last `WATCHDOG=1` was sent.
    last_notify: Instant,
}
impl SdNotify {
pub fn from_env() -> Self {
let sock = std::env::var("NOTIFY_SOCKET").ok().and_then(|addr| {
open_notify_socket(&addr)
.map_err(|e| {
crate::varta_warn!("sd_notify: could not open {addr:?}: {e}");
})
.ok()
});
let watchdog_interval = std::env::var("WATCHDOG_USEC")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.filter(|&us| us > 0)
.map(|us| Duration::from_micros(us / 2));
Self {
sock,
watchdog_interval,
last_notify: Instant::now(),
}
}
pub fn ready(&mut self) {
self.send(b"READY=1\n");
}
pub fn watchdog_tick(&mut self) {
if let Some(interval) = self.watchdog_interval {
if self.last_notify.elapsed() >= interval {
self.send(b"WATCHDOG=1\n");
self.last_notify = Instant::now();
}
}
}
pub fn stopping(&mut self) {
self.send(b"STOPPING=1\n");
}
pub fn take_watchdog_notifier(&mut self) -> Option<WatchdogNotifier> {
let interval = self.watchdog_interval.take()?;
let sock = self.sock.as_ref()?.try_clone().ok()?;
Some(WatchdogNotifier {
sock,
interval,
last_notify: self.last_notify,
})
}
pub fn watchdog_half_interval(&self) -> Option<Duration> {
self.watchdog_interval
}
fn send(&self, msg: &[u8]) {
if let Some(ref sock) = self.sock {
let _ = sock.send(msg);
}
}
}
impl WatchdogNotifier {
    /// Emits `WATCHDOG=1` once at least one ping period has elapsed since the
    /// previous ping, then restarts the period; earlier calls do nothing.
    ///
    /// Delivery is best-effort: a failed send is ignored, and the period is
    /// restarted regardless.
    pub fn tick(&mut self) {
        let since_last = self.last_notify.elapsed();
        if since_last < self.interval {
            return;
        }
        let _ = self.sock.send(b"WATCHDOG=1\n");
        self.last_notify = Instant::now();
    }

    /// The ping period. It is half of `WATCHDOG_USEC` when this notifier was
    /// obtained from [`SdNotify::take_watchdog_notifier`].
    pub fn half_interval(&self) -> Duration {
        self.interval
    }
}
/// Opens an unbound datagram socket and connects it to a `NOTIFY_SOCKET` address.
///
/// A leading `@` selects the abstract namespace (the rest is the name);
/// anything else is used as a filesystem path.
fn open_notify_socket(addr: &str) -> std::io::Result<UnixDatagram> {
    let socket = UnixDatagram::unbound()?;
    match addr.strip_prefix('@') {
        Some(abstract_name) => connect_abstract(&socket, abstract_name)?,
        None => socket.connect(addr)?,
    }
    Ok(socket)
}
/// Connects `sock` to the abstract-namespace Unix socket called `name`.
///
/// `name` is the `NOTIFY_SOCKET` value without its leading `@`. The address
/// passed to `connect(2)` is a Linux `sockaddr_un` whose `sun_path` is one NUL
/// byte followed by `name`. The address length covers exactly those bytes, so
/// the kernel matches the name as-is rather than a NUL-padded version of it.
///
/// # Errors
///
/// Returns `InvalidInput` if `name` is 108 bytes or longer (it must fit in
/// `sun_path` after the leading NUL). Otherwise returns whatever error
/// `connect(2)` reports.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn connect_abstract(sock: &UnixDatagram, name: &str) -> std::io::Result<()> {
    use std::os::unix::io::AsRawFd;

    // AF_UNIX on Linux.
    const AF_UNIX: u16 = 1;
    // Length of `sockaddr_un::sun_path` on Linux.
    const SUN_PATH_LEN: usize = 108;

    // Mirror of Linux `struct sockaddr_un` (`sa_family_t` is an unsigned short).
    // Using a typed field stores the family in native byte order, which a
    // hand-written `[1, 0, ...]` byte prefix only gets right on little-endian CPUs.
    #[repr(C)]
    // The fields are read by the kernel through the raw pointer, never by Rust.
    #[allow(dead_code)]
    struct SockaddrUn {
        sun_family: u16,
        sun_path: [u8; SUN_PATH_LEN],
    }

    extern "C" {
        fn connect(
            sockfd: std::ffi::c_int,
            addr: *const std::ffi::c_void,
            addrlen: libc_socklen_t,
        ) -> std::ffi::c_int;
    }

    let name_bytes = name.as_bytes();
    // sun_path[0] is reserved for the NUL that selects the abstract namespace.
    if name_bytes.len() >= SUN_PATH_LEN {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "NOTIFY_SOCKET abstract name too long",
        ));
    }

    let mut addr = SockaddrUn {
        sun_family: AF_UNIX,
        sun_path: [0; SUN_PATH_LEN],
    };
    addr.sun_path[1..=name_bytes.len()].copy_from_slice(name_bytes);
    // Family field + leading NUL + name: at most 2 + 1 + 107 = 110, so the cast is lossless.
    let addrlen = (std::mem::size_of::<u16>() + 1 + name_bytes.len()) as libc_socklen_t;

    // SAFETY: `addr` is a fully initialised `repr(C)` mirror of `sockaddr_un`
    // that outlives the call. `addrlen` never exceeds `size_of::<SockaddrUn>()`.
    // The fd is borrowed from `sock`, which stays open for the whole call.
    let rc = unsafe {
        connect(
            sock.as_raw_fd(),
            (&addr as *const SockaddrUn).cast::<std::ffi::c_void>(),
            addrlen,
        )
    };
    if rc == -1 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(())
    }
}

/// Abstract-namespace sockets are Linux-only. On other platforms an `@name`
/// `NOTIFY_SOCKET` is rejected with `Unsupported`.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
fn connect_abstract(_sock: &UnixDatagram, _name: &str) -> std::io::Result<()> {
    Err(std::io::Error::new(
        std::io::ErrorKind::Unsupported,
        "abstract sockets are only available on Linux",
    ))
}

/// C `socklen_t` for the `connect(2)` binding used by the Linux `connect_abstract`.
// Named after the C type it mirrors, hence the non-camel-case name.
#[cfg(any(target_os = "linux", target_os = "android"))]
#[allow(non_camel_case_types)]
type libc_socklen_t = u32;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_op_when_env_unset() {
        // Mutates process-global environment; no other test here reads NOTIFY_SOCKET.
        let prev = std::env::var("NOTIFY_SOCKET").ok();
        unsafe { std::env::remove_var("NOTIFY_SOCKET") };
        let mut n = SdNotify::from_env();
        n.ready();
        n.watchdog_tick();
        n.stopping();
        assert!(n.sock.is_none());
        if let Some(v) = prev {
            unsafe { std::env::set_var("NOTIFY_SOCKET", v) };
        }
    }

    #[test]
    fn watchdog_tick_does_not_send_before_interval() {
        let (listener, sender) =
            UnixDatagram::pair().expect("socketpair for hermetic notify test");
        listener.set_nonblocking(true).expect("set nonblocking");
        let mut n = SdNotify {
            sock: Some(sender),
            watchdog_interval: Some(Duration::from_secs(60)),
            last_notify: Instant::now(),
        };
        n.watchdog_tick();
        let mut buf = [0u8; 64];
        let err = listener
            .recv(&mut buf)
            .expect_err("no WATCHDOG=1 may be sent before the interval elapses");
        assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
    }

    #[test]
    fn watchdog_tick_sends_once_interval_elapsed() {
        let (listener, sender) =
            UnixDatagram::pair().expect("socketpair for hermetic notify test");
        listener
            .set_read_timeout(Some(Duration::from_millis(200)))
            .expect("set read timeout");
        let mut n = SdNotify {
            sock: Some(sender),
            watchdog_interval: Some(Duration::from_millis(1)),
            last_notify: Instant::now() - Duration::from_secs(1),
        };
        n.watchdog_tick();
        let mut buf = [0u8; 64];
        let nread = listener.recv(&mut buf).expect("recv WATCHDOG=1");
        assert_eq!(&buf[..nread], b"WATCHDOG=1\n");
    }

    #[test]
    fn take_watchdog_notifier_disarms_main_thread_emission() {
        let mut n = SdNotify {
            sock: None,
            watchdog_interval: Some(Duration::from_micros(100)),
            last_notify: Instant::now(),
        };
        let taken = n.take_watchdog_notifier();
        assert!(taken.is_none());
        assert!(
            n.watchdog_half_interval().is_none(),
            "take_watchdog_notifier must consume the interval even when sock is absent",
        );
        n.watchdog_tick();
    }

    #[test]
    fn take_watchdog_notifier_clones_socket_and_emits_independently() {
        let (listener, sender) =
            UnixDatagram::pair().expect("socketpair for hermetic notify test");
        listener
            .set_read_timeout(Some(Duration::from_millis(200)))
            .expect("set read timeout");
        let mut n = SdNotify {
            sock: Some(sender),
            watchdog_interval: Some(Duration::from_micros(0)),
            last_notify: Instant::now() - Duration::from_secs(1),
        };
        let mut wdt = n
            .take_watchdog_notifier()
            .expect("take_watchdog_notifier when sock + interval are set");
        assert!(n.watchdog_half_interval().is_none());
        n.ready();
        let mut buf = [0u8; 64];
        let nread = listener.recv(&mut buf).expect("recv READY=1 from main fd");
        assert_eq!(&buf[..nread], b"READY=1\n");
        wdt.tick();
        let nread = listener
            .recv(&mut buf)
            .expect("recv WATCHDOG=1 from dup fd");
        assert_eq!(&buf[..nread], b"WATCHDOG=1\n");
        n.stopping();
        let nread = listener
            .recv(&mut buf)
            .expect("recv STOPPING=1 from main fd");
        assert_eq!(&buf[..nread], b"STOPPING=1\n");
    }
}