1use std::io::{self, BufRead, BufReader, Read, Write};
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, Mutex};
12
13use interprocess::local_socket::{
14 prelude::*, Listener, ListenerNonblockingMode, ListenerOptions, Stream,
15};
16
17use crate::workspace::encode_path_for_filename;
18
19#[cfg(unix)]
21mod platform_unix;
22#[cfg(windows)]
23mod platform_windows;
24
25#[cfg(unix)]
26use platform_unix as platform;
27#[cfg(windows)]
28use platform_windows as platform;
29
/// Filesystem locations for one server session: two socket paths and a pid file.
///
/// On Windows, `ServerListener::bind` writes plain marker files at `data` and
/// `control` after creating the listeners.
#[derive(Debug, Clone)]
pub struct SocketPaths {
    /// Path of the data socket (`<name>.data.sock`).
    pub data: PathBuf,
    /// Path of the control socket (`<name>.ctrl.sock`).
    pub control: PathBuf,
    /// Path of the file holding the server's process id (`<name>.pid`).
    pub pid: PathBuf,
}
40
41impl SocketPaths {
42 pub fn socket_directory() -> io::Result<PathBuf> {
44 platform::get_socket_dir()
45 }
46
47 pub fn for_working_dir(working_dir: &Path) -> io::Result<Self> {
49 let socket_dir = platform::get_socket_dir()?;
50 let encoded = encode_path_for_filename(working_dir);
51
52 Ok(Self {
53 data: socket_dir.join(format!("{}.data.sock", encoded)),
54 control: socket_dir.join(format!("{}.ctrl.sock", encoded)),
55 pid: socket_dir.join(format!("{}.pid", encoded)),
56 })
57 }
58
59 pub fn for_session_name(name: &str) -> io::Result<Self> {
61 let socket_dir = platform::get_socket_dir()?;
62 Ok(Self::for_session_name_in_dir(name, &socket_dir))
63 }
64
65 pub fn for_session_name_in_dir(name: &str, socket_dir: &Path) -> Self {
68 Self {
69 data: socket_dir.join(format!("{}.data.sock", name)),
70 control: socket_dir.join(format!("{}.ctrl.sock", name)),
71 pid: socket_dir.join(format!("{}.pid", name)),
72 }
73 }
74
75 pub fn exists(&self) -> bool {
77 self.data.exists() && self.control.exists()
78 }
79
80 pub fn write_pid(&self, pid: u32) -> io::Result<()> {
82 std::fs::write(&self.pid, pid.to_string())
83 }
84
85 pub fn read_pid(&self) -> io::Result<Option<u32>> {
87 if !self.pid.exists() {
88 return Ok(None);
89 }
90 let content = std::fs::read_to_string(&self.pid)?;
91 Ok(content.trim().parse().ok())
92 }
93
94 pub fn is_server_alive(&self) -> bool {
96 use crate::server::daemon::is_process_running;
97
98 if let Ok(Some(pid)) = self.read_pid() {
100 if is_process_running(pid) {
101 return true;
102 }
103 }
104
105 if self.exists() {
107 return platform::check_server_by_connect(&self.control);
108 }
109
110 false
111 }
112
113 pub fn cleanup_if_stale(&self) -> bool {
116 if self.exists() && !self.is_server_alive() {
117 #[allow(clippy::let_underscore_must_use)]
119 let _ = self.cleanup();
120 true
121 } else {
122 false
123 }
124 }
125
126 pub fn cleanup(&self) -> io::Result<()> {
128 if self.data.exists() {
129 std::fs::remove_file(&self.data)?;
130 }
131 if self.control.exists() {
132 std::fs::remove_file(&self.control)?;
133 }
134 if self.pid.exists() {
135 std::fs::remove_file(&self.pid)?;
136 }
137 Ok(())
138 }
139}
140
/// Stream type used for both data and control connections (alias of
/// `interprocess::local_socket::Stream`).
type LocalStream = Stream;
/// Listener type used for both data and control sockets (alias of
/// `interprocess::local_socket::Listener`).
type LocalListener = Listener;
144
/// Server side of a session: one listener for the data socket and one for the
/// control socket.
///
/// Dropping it calls `SocketPaths::cleanup` on the stored paths.
pub struct ServerListener {
    /// Accepts data connections.
    data_listener: LocalListener,
    /// Accepts control connections; polled without blocking in `accept`.
    control_listener: LocalListener,
    /// Paths the listeners were bound to; also cleaned up on drop.
    paths: SocketPaths,
}
151
impl ServerListener {
    /// Creates the data and control listeners described by `paths`.
    ///
    /// Steps, in order:
    /// 1. Remove any existing files at `paths` via `SocketPaths::cleanup`.
    /// 2. Create the parent directory of the data socket.
    /// 3. Create both listeners.
    /// 4. On Windows only, write marker files at the data/control paths
    ///    (presumably so `SocketPaths::exists` works with named pipes — TODO confirm).
    ///
    /// NOTE(review): step 1 runs unconditionally, so callers presumably check
    /// for a live server before binding — confirm.
    ///
    /// # Errors
    /// Cleanup, directory creation, name resolution and marker-file writes
    /// propagate their own errors. A failure to create either listener is logged
    /// and reported as `io::ErrorKind::AddrInUse`, whatever the underlying cause.
    pub fn bind(paths: SocketPaths) -> io::Result<Self> {
        tracing::debug!("ServerListener::bind starting for {:?}", paths.data);

        // Clear any existing socket/pid files at these paths before binding.
        paths.cleanup()?;

        if let Some(parent) = paths.data.parent() {
            tracing::debug!("Creating socket directory: {:?}", parent);
            std::fs::create_dir_all(parent)?;
        }

        let data_name = platform::socket_name_for_path(&paths.data)?;
        let control_name = platform::socket_name_for_path(&paths.control)?;

        tracing::debug!("Creating data listener...");
        let data_listener = ListenerOptions::new()
            .name(data_name)
            .create_sync()
            .map_err(|e| {
                tracing::error!("Failed to create data listener: {}", e);
                // Original error kind is replaced; only its message is kept.
                io::Error::new(io::ErrorKind::AddrInUse, e.to_string())
            })?;

        tracing::debug!("Creating control listener...");
        let control_listener = ListenerOptions::new()
            .name(control_name)
            .create_sync()
            .map_err(|e| {
                tracing::error!("Failed to create control listener: {}", e);
                io::Error::new(io::ErrorKind::AddrInUse, e.to_string())
            })?;

        #[cfg(windows)]
        {
            tracing::debug!("Writing marker files...");
            std::fs::write(&paths.data, "socket")?;
            std::fs::write(&paths.control, "socket")?;
        }

        tracing::info!("Server listening on {:?}", paths.data);

        Ok(Self {
            data_listener,
            control_listener,
            paths,
        })
    }

    /// Accepts one client, pairing its control and data connections.
    ///
    /// The control listener is polled without blocking. If no client is waiting
    /// (`WouldBlock`, or on Windows a transient pipe error), `Ok(None)` is
    /// returned. Once a control connection is accepted, the data listener is put
    /// into blocking mode and this call waits for the next data connection.
    ///
    /// On non-Windows platforms both accepted streams are then switched to
    /// non-blocking mode. A failure on the data stream is ignored; a failure on
    /// the control stream is returned.
    ///
    /// NOTE(review): the pairing relies on each client connecting control and
    /// then data back-to-back, as `ClientConnection::connect` does. If a client
    /// connects control but never connects data, this blocks indefinitely. With
    /// concurrent clients, streams could be paired across clients. Confirm this
    /// is acceptable.
    pub fn accept(&mut self) -> io::Result<Option<ServerConnection>> {
        // Make `accept` on the control listener non-blocking for this poll.
        self.control_listener
            .set_nonblocking(ListenerNonblockingMode::Accept)?;

        let control_stream = match self.control_listener.accept() {
            Ok(stream) => stream,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                return Ok(None);
            }
            #[cfg(windows)]
            Err(e) if platform_windows::is_transient_pipe_error(&e) => {
                return Ok(None);
            }
            Err(e) => return Err(e),
        };

        // A control connection arrived: wait (blocking) for the data connection.
        self.data_listener
            .set_nonblocking(ListenerNonblockingMode::Neither)?;
        let data_stream = self.data_listener.accept()?;

        #[cfg(not(windows))]
        {
            // Best-effort for the data stream; the control stream must succeed.
            #[allow(clippy::let_underscore_must_use)]
            let _ = data_stream.set_nonblocking(true);
            control_stream.set_nonblocking(true)?;
        }

        Ok(Some(ServerConnection {
            data: StreamWrapper::new(data_stream),
            control: StreamWrapper::new(control_stream),
        }))
    }

    /// Returns the socket paths this listener was bound with.
    pub fn paths(&self) -> &SocketPaths {
        &self.paths
    }
}
252
253impl Drop for ServerListener {
254 fn drop(&mut self) {
255 #[allow(clippy::let_underscore_must_use)]
257 let _ = self.paths.cleanup();
258 }
259}
260
/// Cloneable, thread-shareable handle to one local-socket stream.
///
/// Clones share the same underlying stream (`Arc`), and every operation goes
/// through the inner `Mutex`.
#[derive(Clone)]
pub struct StreamWrapper(Arc<Mutex<LocalStream>>);
265
266impl StreamWrapper {
267 fn new(stream: LocalStream) -> Self {
269 Self(Arc::new(Mutex::new(stream)))
270 }
271
272 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
274 self.0
275 .lock()
276 .map_err(|_| io::Error::other("mutex poisoned"))?
277 .set_nonblocking(nonblocking)
278 }
279
280 pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
282 let mut guard = self
283 .0
284 .lock()
285 .map_err(|_| io::Error::other("mutex poisoned"))?;
286 Write::write_all(&mut *guard, buf)
287 }
288
289 pub fn flush(&self) -> io::Result<()> {
291 let mut guard = self
292 .0
293 .lock()
294 .map_err(|_| io::Error::other("mutex poisoned"))?;
295 Write::flush(&mut *guard)
296 }
297
298 pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
300 let mut guard = match self.0.try_lock() {
301 Ok(g) => g,
302 Err(std::sync::TryLockError::WouldBlock) => {
303 return Err(io::Error::new(
304 io::ErrorKind::WouldBlock,
305 "stream busy (mutex contended)",
306 ));
307 }
308 Err(std::sync::TryLockError::Poisoned(_)) => {
309 return Err(io::Error::other("mutex poisoned"));
310 }
311 };
312
313 platform::try_read_nonblocking(&mut guard, buf)
314 }
315}
316
/// On Windows, rewrites transient named-pipe errors as `WouldBlock` so callers
/// treat them like an empty non-blocking read. All other results pass through.
#[cfg(windows)]
#[inline]
fn map_windows_pipe_error(result: io::Result<usize>) -> io::Result<usize> {
    result.map_err(|err| {
        if platform_windows::is_transient_pipe_error(&err) {
            io::Error::new(io::ErrorKind::WouldBlock, err)
        } else {
            err
        }
    })
}

/// On non-Windows platforms there are no pipe errors to translate, so the
/// result is returned unchanged.
#[cfg(not(windows))]
#[inline]
fn map_windows_pipe_error(result: io::Result<usize>) -> io::Result<usize> {
    result
}
328
329impl Read for StreamWrapper {
330 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
331 let result = self
332 .0
333 .lock()
334 .map_err(|_| io::Error::other("mutex poisoned"))?
335 .read(buf);
336 map_windows_pipe_error(result)
337 }
338}
339
340impl Read for &StreamWrapper {
341 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
342 let result = self
343 .0
344 .lock()
345 .map_err(|_| io::Error::other("mutex poisoned"))?
346 .read(buf);
347 map_windows_pipe_error(result)
348 }
349}
350
/// Server-side view of one connected client, as returned by `ServerListener::accept`.
pub struct ServerConnection {
    /// Data stream; non-blocking on non-Windows platforms after `accept`.
    pub data: StreamWrapper,
    /// Control stream carrying newline-delimited messages.
    pub control: StreamWrapper,
}
358
359impl ServerConnection {
360 pub fn read_data(&self, buf: &mut [u8]) -> io::Result<usize> {
362 self.data.try_read(buf)
363 }
364
365 pub fn write_data(&self, buf: &[u8]) -> io::Result<()> {
367 self.data.write_all(buf)?;
368 self.data.flush()
369 }
370
371 pub fn read_control(&self) -> io::Result<Option<String>> {
373 #[cfg(not(windows))]
376 self.control.set_nonblocking(false)?;
377 let mut reader = BufReader::new(&self.control);
378 let mut line = String::new();
379 match reader.read_line(&mut line) {
380 Ok(0) => Ok(None), Ok(_) => Ok(Some(line)),
382 Err(e) => Err(e),
383 }
384 }
385
386 pub fn write_control(&self, msg: &str) -> io::Result<()> {
402 #[cfg(not(windows))]
406 let restore_nonblocking = self.control.set_nonblocking(false).is_ok();
407
408 let result = (|| {
409 self.control.write_all(msg.as_bytes())?;
410 if !msg.ends_with('\n') {
411 self.control.write_all(b"\n")?;
412 }
413 self.control.flush()
414 })();
415
416 #[cfg(not(windows))]
418 if restore_nonblocking {
419 #[allow(clippy::let_underscore_must_use)]
420 let _ = self.control.set_nonblocking(true);
421 }
422
423 result
424 }
425}
426
/// Client-side connection to a running server, created by `ClientConnection::connect`.
pub struct ClientConnection {
    /// Data stream.
    pub data: StreamWrapper,
    /// Control stream carrying newline-delimited messages.
    pub control: StreamWrapper,
}
434
435impl ClientConnection {
436 pub fn connect(paths: &SocketPaths) -> io::Result<Self> {
438 let control_name = platform::socket_name_for_path(&paths.control)?;
439 let data_name = platform::socket_name_for_path(&paths.data)?;
440
441 let control = Stream::connect(control_name)
443 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
444
445 let data = Stream::connect(data_name)
446 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
447
448 Ok(Self {
449 data: StreamWrapper::new(data),
450 control: StreamWrapper::new(control),
451 })
452 }
453
454 pub fn set_data_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
456 self.data.set_nonblocking(nonblocking)
457 }
458
459 pub fn read_data(&self, buf: &mut [u8]) -> io::Result<usize> {
461 self.data.try_read(buf)
462 }
463
464 pub fn write_data(&self, buf: &[u8]) -> io::Result<()> {
466 self.data.write_all(buf)?;
467 self.data.flush()
468 }
469
470 pub fn read_control(&self) -> io::Result<Option<String>> {
472 let mut reader = BufReader::new(&self.control);
473 let mut line = String::new();
474 match reader.read_line(&mut line) {
475 Ok(0) => Ok(None),
476 Ok(_) => Ok(Some(line)),
477 Err(e) => Err(e),
478 }
479 }
480
481 pub fn write_control(&self, msg: &str) -> io::Result<()> {
483 self.control.write_all(msg.as_bytes())?;
484 if !msg.ends_with('\n') {
485 self.control.write_all(b"\n")?;
486 }
487 self.control.flush()
488 }
489
490 #[cfg(unix)]
492 pub fn as_raw_fds(&self) -> (std::os::unix::io::RawFd, std::os::unix::io::RawFd) {
493 use std::os::unix::io::{AsFd, AsRawFd};
494 let data_guard = self.data.0.lock().unwrap();
495 let ctrl_guard = self.control.0.lock().unwrap();
496 let data_fd = match &*data_guard {
497 Stream::UdSocket(s) => s.as_fd().as_raw_fd(),
498 };
499 let ctrl_fd = match &*ctrl_guard {
500 Stream::UdSocket(s) => s.as_fd().as_raw_fd(),
501 };
502 (data_fd, ctrl_fd)
503 }
504}
505
// Unit tests for socket path derivation, pid-file handling and control-channel I/O.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // The working directory must be encoded into the file name, with the
    // expected socket suffixes.
    #[test]
    fn test_socket_paths_encode_working_dir() {
        let paths = SocketPaths::for_working_dir(Path::new("/home/user/project")).unwrap();
        assert!(paths.data.to_string_lossy().contains("home_user_project"));
        assert!(paths.data.to_string_lossy().ends_with(".data.sock"));
        assert!(paths.control.to_string_lossy().ends_with(".ctrl.sock"));
    }

    // Named sessions use the name verbatim as the file-name stem.
    #[test]
    fn test_named_session_uses_name_directly() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("my-session", temp_dir.path());
        assert!(paths
            .data
            .to_string_lossy()
            .contains("my-session.data.sock"));
        assert!(paths
            .control
            .to_string_lossy()
            .contains("my-session.ctrl.sock"));
    }

    #[test]
    fn test_exists_returns_false_for_missing_sockets() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        assert!(!paths.exists());
    }

    // Cleanup of files that were never created must not fail.
    #[test]
    fn test_cleanup_succeeds_on_missing_files() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        assert!(paths.cleanup().is_ok());
    }

    // Uses the real platform socket directory, not a temp dir.
    #[test]
    fn test_socket_directory_creates_dir() {
        let dir = SocketPaths::socket_directory().unwrap();
        assert!(dir.exists());
        assert!(dir.is_dir());
    }

    #[test]
    fn test_different_working_dirs_get_different_paths() {
        let paths1 = SocketPaths::for_working_dir(Path::new("/home/user/project1")).unwrap();
        let paths2 = SocketPaths::for_working_dir(Path::new("/home/user/project2")).unwrap();
        assert_ne!(paths1.data, paths2.data);
        assert_ne!(paths1.control, paths2.control);
    }

    // Path derivation must be deterministic.
    #[test]
    fn test_same_working_dir_gets_same_paths() {
        let paths1 = SocketPaths::for_working_dir(Path::new("/home/user/project")).unwrap();
        let paths2 = SocketPaths::for_working_dir(Path::new("/home/user/project")).unwrap();
        assert_eq!(paths1.data, paths2.data);
        assert_eq!(paths1.control, paths2.control);
    }

    #[test]
    fn test_pid_file_path_included() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("pid-test-session", temp_dir.path());
        assert!(paths.pid.to_string_lossy().contains("pid-test-session.pid"));
    }

    // Round-trips a pid through the pid file, then checks cleanup removes it.
    #[test]
    fn test_write_and_read_pid() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());

        paths.write_pid(12345).unwrap();
        assert!(paths.pid.exists());

        let pid = paths.read_pid().unwrap();
        assert_eq!(pid, Some(12345));

        paths.cleanup().unwrap();
        assert!(!paths.pid.exists());
    }

    #[test]
    fn test_read_pid_returns_none_for_missing_file() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        assert_eq!(paths.read_pid().unwrap(), None);
    }

    // With no socket files present there is nothing stale to clean up.
    #[test]
    fn test_cleanup_if_stale_with_no_sockets() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());

        assert!(!paths.cleanup_if_stale());
    }

    // Regression test: `ServerConnection::write_control` must deliver a message
    // larger than the socket buffer (4 MiB here) even though the server's
    // control stream is in non-blocking mode.
    #[cfg(unix)]
    #[test]
    fn test_write_control_delivers_large_message_when_nonblocking() {
        use std::sync::mpsc;
        use std::thread;
        use std::time::Duration;

        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("ctrl-large-write", temp_dir.path());
        let mut listener = ServerListener::bind(paths.clone()).unwrap();

        let big_text = "X".repeat(4 * 1024 * 1024);
        let msg = serde_json::to_string(&crate::server::protocol::ServerControl::SetClipboard {
            text: big_text.clone(),
            use_osc52: true,
            use_system_clipboard: true,
        })
        .unwrap();

        let (connected_tx, connected_rx) = mpsc::channel::<()>();
        let (result_tx, result_rx) = mpsc::channel::<io::Result<Option<String>>>();
        let paths_client = paths.clone();
        let reader = thread::spawn(move || {
            let conn = ClientConnection::connect(&paths_client).unwrap();
            connected_tx.send(()).unwrap();
            // Delay reading so the server's write fills the socket buffer
            // (presumably forcing WouldBlock on a non-blocking write).
            thread::sleep(Duration::from_millis(300));
            let received = conn.read_control();
            #[allow(clippy::let_underscore_must_use)]
            let _ = result_tx.send(received);
        });

        // Poll until the client's control+data pair has been accepted.
        let server_conn = loop {
            if let Some(c) = listener.accept().unwrap() {
                break c;
            }
            thread::sleep(Duration::from_millis(5));
        };

        server_conn.control.set_nonblocking(true).unwrap();

        connected_rx.recv().unwrap();
        // The write result is ignored; success is judged by what the reader
        // actually receives below.
        #[allow(clippy::let_underscore_must_use)]
        let _ = server_conn.write_control(&msg);

        let received = result_rx
            .recv()
            .expect("reader thread dropped the channel")
            .expect("control read failed")
            .expect("control stream closed unexpectedly");

        match serde_json::from_str::<crate::server::protocol::ServerControl>(received.trim())
            .expect("received control message should be valid JSON")
        {
            crate::server::protocol::ServerControl::SetClipboard { text, .. } => {
                assert_eq!(
                    text.len(),
                    big_text.len(),
                    "the full clipboard payload must be delivered intact"
                );
            }
            other => panic!("unexpected control message: {:?}", other),
        }

        reader.join().unwrap();
    }
}