1use std::io::Write as _;
35use std::os::unix::net::UnixStream;
36use std::path::{Path, PathBuf};
37use std::time::Duration;
38
/// Payload written to each discovered socket to request a refresh.
const REFRESH_PAYLOAD: &[u8] = b"refresh\n";
/// File extension (without the leading dot) used for the per-process sockets.
const SOCKET_EXTENSION: &str = "sock";
/// Upper bound on how long writing [`REFRESH_PAYLOAD`] to one socket may block.
///
/// NOTE(review): despite the name, this is only applied through
/// `set_write_timeout` after the connection is already established; the
/// `connect` call itself is not time-bounded.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(50);
44
/// Directory in which per-process sockets are created and discovered.
///
/// Uses `$XDG_RUNTIME_DIR` when it holds an absolute path and falls back to
/// `/tmp` otherwise. The XDG Base Directory specification requires relative
/// values to be treated as invalid and ignored. Honouring an empty or
/// relative value would resolve against the current working directory, so
/// two processes started from different directories would look in different
/// places. Socket file names carry the effective uid, which keeps users
/// apart in the shared `/tmp` fallback.
#[must_use]
pub fn socket_dir() -> PathBuf {
    std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| PathBuf::from("/tmp"))
}
54
/// Effective user id of the current process; embedded in socket file names
/// so each user only sees their own sockets.
fn euid() -> u32 {
    // SAFETY: `geteuid` takes no arguments, touches no caller-owned memory and,
    // per POSIX, always succeeds, so there is no invariant for us to uphold.
    unsafe { libc::geteuid() }
}
60
61fn socket_prefix() -> String {
64 format!("modde-{}-", euid())
65}
66
67#[must_use]
70pub fn gui_socket_path() -> PathBuf {
71 let pid = std::process::id();
72 socket_dir().join(format!("{}{pid}.{SOCKET_EXTENSION}", socket_prefix()))
73}
74
/// Removes the socket file at `path`, if there is one.
///
/// Best-effort: the outcome of the removal (including "not found") is
/// deliberately discarded, so calling this repeatedly is harmless.
pub fn cleanup_socket(path: &Path) {
    drop(std::fs::remove_file(path));
}
80
81pub fn notify_refresh() -> usize {
91 notify_refresh_in(&socket_dir())
92}
93
94pub fn notify_refresh_in(dir: &Path) -> usize {
98 let prefix = socket_prefix();
99 let suffix = format!(".{SOCKET_EXTENSION}");
100 let Ok(entries) = std::fs::read_dir(dir) else {
101 return 0;
102 };
103
104 let mut delivered = 0usize;
105 for entry in entries.flatten() {
106 let name = entry.file_name();
107 let Some(name_str) = name.to_str() else {
108 continue;
109 };
110 if !name_str.starts_with(&prefix) || !name_str.ends_with(&suffix) {
111 continue;
112 }
113 let path = entry.path();
114 if notify_refresh_at(&path) {
115 delivered += 1;
116 } else {
117 let _ = std::fs::remove_file(&path);
120 }
121 }
122 delivered
123}
124
125pub fn notify_refresh_at(path: &Path) -> bool {
128 let Ok(mut stream) = UnixStream::connect(path) else {
129 return false;
130 };
131 let _ = stream.set_write_timeout(Some(CONNECT_TIMEOUT));
132 stream.write_all(REFRESH_PAYLOAD).is_ok()
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read as _;
    use std::os::unix::net::UnixListener;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use tempfile::TempDir;

    /// Source of unique suffixes so every fake socket name in this test binary
    /// is distinct, even when tests run in parallel.
    static FAKE_PID: AtomicUsize = AtomicUsize::new(1);

    /// Returns a fresh path in `dir` that carries the current user's prefix and
    /// the socket extension, i.e. a name `notify_refresh_in` will scan.
    fn fake_socket_path(dir: &Path) -> PathBuf {
        let pid = FAKE_PID.fetch_add(1, Ordering::Relaxed);
        dir.join(format!("{}test{pid}.{SOCKET_EXTENSION}", socket_prefix()))
    }

    /// Accepts one connection on `listener` in a background thread and returns
    /// every byte the client wrote before closing.
    ///
    /// Callers may connect immediately after spawning. `UnixListener::bind`
    /// already leaves the socket listening, so the kernel queues the connection
    /// (and buffers the small payload) until `accept` runs. No start-up sleep
    /// is needed.
    fn spawn_drain(listener: UnixListener) -> thread::JoinHandle<Vec<u8>> {
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).unwrap();
            buf
        })
    }

    #[test]
    fn notify_at_returns_false_when_no_listener() {
        let tmp = TempDir::new().unwrap();
        let path = fake_socket_path(tmp.path());
        assert!(!notify_refresh_at(&path));
    }

    #[test]
    fn notify_at_delivers_to_listener() {
        let tmp = TempDir::new().unwrap();
        let path = fake_socket_path(tmp.path());
        let handle = spawn_drain(UnixListener::bind(&path).unwrap());

        assert!(notify_refresh_at(&path));
        assert_eq!(handle.join().unwrap(), REFRESH_PAYLOAD);
    }

    #[test]
    fn notify_in_returns_zero_for_empty_dir() {
        let tmp = TempDir::new().unwrap();
        assert_eq!(notify_refresh_in(tmp.path()), 0);
    }

    #[test]
    fn notify_in_returns_zero_for_missing_dir() {
        let tmp = TempDir::new().unwrap();
        let missing = tmp.path().join("does-not-exist");
        assert_eq!(notify_refresh_in(&missing), 0);
    }

    #[test]
    fn notify_in_delivers_to_every_listener() {
        let tmp = TempDir::new().unwrap();
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let path = fake_socket_path(tmp.path());
                spawn_drain(UnixListener::bind(&path).unwrap())
            })
            .collect();

        assert_eq!(notify_refresh_in(tmp.path()), 3);
        for handle in handles {
            assert_eq!(handle.join().unwrap(), REFRESH_PAYLOAD);
        }
    }

    #[test]
    fn notify_in_garbage_collects_stale_sockets_and_keeps_live_ones() {
        let tmp = TempDir::new().unwrap();

        // Plain files with matching names: nothing can accept a connection on
        // them, so they stand in for sockets left behind by exited processes.
        let stale_a = fake_socket_path(tmp.path());
        let stale_b = fake_socket_path(tmp.path());
        std::fs::write(&stale_a, b"").unwrap();
        std::fs::write(&stale_b, b"").unwrap();

        let live_path = fake_socket_path(tmp.path());
        let handle = spawn_drain(UnixListener::bind(&live_path).unwrap());

        let delivered = notify_refresh_in(tmp.path());
        assert_eq!(delivered, 1, "only the live listener should receive");
        assert!(!stale_a.exists(), "stale socket A should be GC'd");
        assert!(!stale_b.exists(), "stale socket B should be GC'd");
        assert!(live_path.exists(), "live socket must not be GC'd");
        assert_eq!(handle.join().unwrap(), REFRESH_PAYLOAD);
    }

    #[test]
    fn notify_in_skips_files_outside_the_user_prefix() {
        let tmp = TempDir::new().unwrap();
        // Derive a uid that is guaranteed to differ from ours, so the test stays
        // valid whichever user runs it.
        let other_uid = euid().wrapping_add(1);
        let other_user = tmp
            .path()
            .join(format!("modde-{other_uid}-pid42.{SOCKET_EXTENSION}"));
        std::fs::write(&other_user, b"").unwrap();

        let delivered = notify_refresh_in(tmp.path());
        assert_eq!(delivered, 0);
        assert!(other_user.exists(), "other-user file must be left alone");
    }

    #[test]
    fn notify_in_skips_files_with_other_extensions() {
        let tmp = TempDir::new().unwrap();
        let lock = tmp.path().join(format!("{}pid1.lock", socket_prefix()));
        std::fs::write(&lock, b"").unwrap();

        assert_eq!(notify_refresh_in(tmp.path()), 0);
        assert!(lock.exists(), "non-socket files must be left alone");
    }

    #[test]
    fn cleanup_socket_removes_file() {
        let tmp = TempDir::new().unwrap();
        let path = fake_socket_path(tmp.path());
        std::fs::write(&path, b"").unwrap();
        assert!(path.exists());
        cleanup_socket(&path);
        assert!(!path.exists());
    }

    #[test]
    fn cleanup_socket_is_idempotent() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("never-existed.sock");
        // Neither call may panic, even though there is nothing to remove.
        cleanup_socket(&path);
        cleanup_socket(&path);
        assert!(!path.exists());
    }

    #[test]
    fn gui_socket_path_includes_pid() {
        let path = gui_socket_path();
        let name = path.file_name().unwrap().to_string_lossy().to_string();
        let prefix = socket_prefix();
        assert!(
            name.starts_with(&prefix),
            "expected prefix {prefix} in {name}"
        );
        assert!(name.ends_with(".sock"), "expected .sock suffix in {name}");
        let pid = std::process::id().to_string();
        assert!(name.contains(&pid), "expected pid {pid} in {name}");
    }
}