use std::io::Write as _;
use std::os::unix::net::UnixStream;
use std::path::{Path, PathBuf};
use std::time::Duration;
/// Bytes written to each GUI socket when a refresh notification is delivered.
const REFRESH_PAYLOAD: &[u8] = b"refresh\n";
/// File extension (without the leading dot) of every GUI socket file.
const SOCKET_EXTENSION: &str = "sock";
/// Timeout applied to the *write* of the notification payload.
///
/// Despite the name, this does not bound `UnixStream::connect` (std offers no
/// connect timeout for Unix sockets); it is only passed to `set_write_timeout`.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(50);
/// Directory in which GUI control sockets are created and searched for.
///
/// Uses `$XDG_RUNTIME_DIR` when it holds an absolute path and falls back to `/tmp`
/// otherwise. The XDG Base Directory specification requires these variables to be
/// absolute and says relative values must be ignored; honouring an empty or relative
/// value would make socket paths depend on the current working directory.
#[must_use]
pub fn socket_dir() -> PathBuf {
    std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        // An empty value yields an empty (hence non-absolute) PathBuf and is rejected too.
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| PathBuf::from("/tmp"))
}
/// Effective user id of the current process.
///
/// Embedded in socket file names (via `socket_prefix`) so that sockets created
/// under different effective uids can be told apart.
fn euid() -> u32 {
    // SAFETY: `geteuid` takes no arguments, reads no caller-provided memory, and
    // POSIX specifies that it always succeeds, so calling it cannot cause UB.
    unsafe { libc::geteuid() }
}
/// File-name prefix shared by every socket belonging to the current effective user.
///
/// The trailing `-` after the uid keeps one uid from being a textual prefix of
/// another (e.g. `1000` vs `10000`).
fn socket_prefix() -> String {
    let uid = euid();
    format!("modde-{}-", uid)
}
/// Path of this process's GUI control socket:
/// `<socket_dir>/modde-<euid>-<pid>.sock`.
#[must_use]
pub fn gui_socket_path() -> PathBuf {
    let dir = socket_dir();
    let pid = std::process::id();
    let file_name = format!("{}{pid}.{SOCKET_EXTENSION}", socket_prefix());
    dir.join(file_name)
}
/// Removes the socket file at `path`, ignoring any error.
///
/// Deliberately best-effort: a path that is already gone (or cannot be removed)
/// is not reported, which also makes repeated calls harmless.
pub fn cleanup_socket(path: &Path) {
    std::fs::remove_file(path).ok();
}
/// Sends a refresh request to every GUI socket of the current user in [`socket_dir`].
///
/// Returns how many sockets received it; see [`notify_refresh_in`] for the details
/// of which entries are considered.
pub fn notify_refresh() -> usize {
    let dir = socket_dir();
    notify_refresh_in(dir.as_path())
}
pub fn notify_refresh_in(dir: &Path) -> usize {
let prefix = socket_prefix();
let suffix = format!(".{SOCKET_EXTENSION}");
let Ok(entries) = std::fs::read_dir(dir) else {
return 0;
};
let mut delivered = 0usize;
for entry in entries.flatten() {
let name = entry.file_name();
let Some(name_str) = name.to_str() else {
continue;
};
if !name_str.starts_with(&prefix) || !name_str.ends_with(&suffix) {
continue;
}
let path = entry.path();
if notify_refresh_at(&path) {
delivered += 1;
} else {
let _ = std::fs::remove_file(&path);
}
}
delivered
}
pub fn notify_refresh_at(path: &Path) -> bool {
let Ok(mut stream) = UnixStream::connect(path) else {
return false;
};
let _ = stream.set_write_timeout(Some(CONNECT_TIMEOUT));
stream.write_all(REFRESH_PAYLOAD).is_ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read as _;
    use std::os::unix::net::UnixListener;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use tempfile::TempDir;

    /// Counter that makes every fake socket name unique within the test binary.
    static FAKE_PID: AtomicUsize = AtomicUsize::new(1);

    /// Returns a fresh path in `dir` that matches this user's socket naming scheme.
    fn fake_socket_path(dir: &Path) -> PathBuf {
        let pid = FAKE_PID.fetch_add(1, Ordering::Relaxed);
        dir.join(format!("{}test{pid}.{SOCKET_EXTENSION}", socket_prefix()))
    }

    /// Accepts one connection on `listener` in a background thread and returns
    /// everything the client sent before closing.
    ///
    /// No sleep is needed before connecting: `UnixListener::bind` already leaves the
    /// socket listening, so clients connect (and their small writes are buffered)
    /// even before `accept` runs.
    fn spawn_drain(listener: UnixListener) -> thread::JoinHandle<Vec<u8>> {
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).unwrap();
            buf
        })
    }

    #[test]
    fn notify_at_returns_false_when_no_listener() {
        let tmp = TempDir::new().unwrap();
        let path = fake_socket_path(tmp.path());
        assert!(!notify_refresh_at(&path));
    }

    #[test]
    fn notify_at_delivers_to_listener() {
        let tmp = TempDir::new().unwrap();
        let path = fake_socket_path(tmp.path());
        let listener = UnixListener::bind(&path).unwrap();
        let handle = spawn_drain(listener);
        assert!(notify_refresh_at(&path));
        assert_eq!(handle.join().unwrap(), REFRESH_PAYLOAD);
    }

    #[test]
    fn notify_in_returns_zero_for_empty_dir() {
        let tmp = TempDir::new().unwrap();
        assert_eq!(notify_refresh_in(tmp.path()), 0);
    }

    #[test]
    fn notify_in_returns_zero_for_missing_dir() {
        let tmp = TempDir::new().unwrap();
        let missing = tmp.path().join("does-not-exist");
        assert_eq!(notify_refresh_in(&missing), 0);
    }

    #[test]
    fn notify_in_delivers_to_every_listener() {
        let tmp = TempDir::new().unwrap();
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let path = fake_socket_path(tmp.path());
                spawn_drain(UnixListener::bind(&path).unwrap())
            })
            .collect();
        assert_eq!(notify_refresh_in(tmp.path()), 3);
        for h in handles {
            assert_eq!(h.join().unwrap(), REFRESH_PAYLOAD);
        }
    }

    #[test]
    fn notify_in_garbage_collects_stale_sockets_and_keeps_live_ones() {
        let tmp = TempDir::new().unwrap();
        let stale_a = fake_socket_path(tmp.path());
        let stale_b = fake_socket_path(tmp.path());
        std::fs::write(&stale_a, b"").unwrap();
        std::fs::write(&stale_b, b"").unwrap();
        let live_path = fake_socket_path(tmp.path());
        let listener = UnixListener::bind(&live_path).unwrap();
        let handle = spawn_drain(listener);
        let delivered = notify_refresh_in(tmp.path());
        assert_eq!(delivered, 1, "only the live listener should receive");
        assert!(!stale_a.exists(), "stale socket A should be GC'd");
        assert!(!stale_b.exists(), "stale socket B should be GC'd");
        assert!(live_path.exists(), "live socket must not be GC'd");
        assert_eq!(handle.join().unwrap(), REFRESH_PAYLOAD);
    }

    #[test]
    fn notify_in_skips_files_outside_the_user_prefix() {
        let tmp = TempDir::new().unwrap();
        let other_user = tmp.path().join("modde-99999-pid42.sock");
        std::fs::write(&other_user, b"").unwrap();
        let delivered = notify_refresh_in(tmp.path());
        assert_eq!(delivered, 0);
        assert!(other_user.exists(), "other-user file must be left alone");
    }

    #[test]
    fn notify_in_skips_files_with_other_extensions() {
        let tmp = TempDir::new().unwrap();
        let lock = tmp.path().join(format!("{}pid1.lock", socket_prefix()));
        std::fs::write(&lock, b"").unwrap();
        assert_eq!(notify_refresh_in(tmp.path()), 0);
        assert!(lock.exists(), "non-socket files must be left alone");
    }

    #[test]
    fn cleanup_socket_removes_file() {
        let tmp = TempDir::new().unwrap();
        let path = fake_socket_path(tmp.path());
        std::fs::write(&path, b"").unwrap();
        assert!(path.exists());
        cleanup_socket(&path);
        assert!(!path.exists());
    }

    #[test]
    fn cleanup_socket_is_idempotent() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("never-existed.sock");
        cleanup_socket(&path);
        cleanup_socket(&path);
        assert!(!path.exists());
    }

    #[test]
    fn gui_socket_path_includes_pid() {
        let path = gui_socket_path();
        let name = path.file_name().unwrap().to_string_lossy().to_string();
        let prefix = socket_prefix();
        assert!(
            name.starts_with(&prefix),
            "expected prefix {prefix} in {name}"
        );
        assert!(name.ends_with(".sock"), "expected .sock suffix in {name}");
        let pid = std::process::id().to_string();
        assert!(name.contains(&pid), "expected pid {pid} in {name}");
    }
}