Skip to main content

fresh/server/ipc/
mod.rs

1//! IPC infrastructure for client-server communication
2//!
3//! Uses the `interprocess` crate for cross-platform local sockets:
4//! - Unix domain sockets on Linux/macOS
5//! - Named pipes on Windows
6//!
7//! Each session has two sockets: data (byte stream) and control (JSON messages).
8
9use std::io::{self, BufRead, BufReader, Read, Write};
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, Mutex};
12
13use interprocess::local_socket::{
14    prelude::*, Listener, ListenerNonblockingMode, ListenerOptions, Stream,
15};
16
17use crate::workspace::encode_path_for_filename;
18
19// Platform-specific implementations
20#[cfg(unix)]
21mod platform_unix;
22#[cfg(windows)]
23mod platform_windows;
24
25#[cfg(unix)]
26use platform_unix as platform;
27#[cfg(windows)]
28use platform_windows as platform;
29
/// Filesystem locations that together identify one IPC session.
///
/// The constructors in this module place all three paths in the same socket
/// directory and derive them from one shared file stem.
#[derive(Debug, Clone)]
pub struct SocketPaths {
    /// Socket carrying the raw byte stream.
    pub data: PathBuf,
    /// Socket carrying JSON control messages.
    pub control: PathBuf,
    /// File holding the server's PID, used to tell live sessions from stale ones.
    pub pid: PathBuf,
}
40
41impl SocketPaths {
42    /// Get the socket directory
43    pub fn socket_directory() -> io::Result<PathBuf> {
44        platform::get_socket_dir()
45    }
46
47    /// Get socket paths for a working directory
48    pub fn for_working_dir(working_dir: &Path) -> io::Result<Self> {
49        let socket_dir = platform::get_socket_dir()?;
50        let encoded = encode_path_for_filename(working_dir);
51
52        Ok(Self {
53            data: socket_dir.join(format!("{}.data.sock", encoded)),
54            control: socket_dir.join(format!("{}.ctrl.sock", encoded)),
55            pid: socket_dir.join(format!("{}.pid", encoded)),
56        })
57    }
58
59    /// Get socket paths for a named session
60    pub fn for_session_name(name: &str) -> io::Result<Self> {
61        let socket_dir = platform::get_socket_dir()?;
62        Ok(Self::for_session_name_in_dir(name, &socket_dir))
63    }
64
65    /// Get socket paths for a named session in a specific directory
66    /// (primarily for testing with isolated temp directories)
67    pub fn for_session_name_in_dir(name: &str, socket_dir: &Path) -> Self {
68        Self {
69            data: socket_dir.join(format!("{}.data.sock", name)),
70            control: socket_dir.join(format!("{}.ctrl.sock", name)),
71            pid: socket_dir.join(format!("{}.pid", name)),
72        }
73    }
74
75    /// Check if the sockets exist (server might be running)
76    pub fn exists(&self) -> bool {
77        self.data.exists() && self.control.exists()
78    }
79
80    /// Write the server PID to the PID file
81    pub fn write_pid(&self, pid: u32) -> io::Result<()> {
82        std::fs::write(&self.pid, pid.to_string())
83    }
84
85    /// Read the server PID from the PID file
86    pub fn read_pid(&self) -> io::Result<Option<u32>> {
87        if !self.pid.exists() {
88            return Ok(None);
89        }
90        let content = std::fs::read_to_string(&self.pid)?;
91        Ok(content.trim().parse().ok())
92    }
93
94    /// Check if the server process is still alive
95    pub fn is_server_alive(&self) -> bool {
96        use crate::server::daemon::is_process_running;
97
98        // Check PID file - this is the reliable method
99        if let Ok(Some(pid)) = self.read_pid() {
100            if is_process_running(pid) {
101                return true;
102            }
103        }
104
105        // Platform-specific fallback check
106        if self.exists() {
107            return platform::check_server_by_connect(&self.control);
108        }
109
110        false
111    }
112
113    /// Clean up stale session files if server is not running
114    /// Returns true if files were cleaned up
115    pub fn cleanup_if_stale(&self) -> bool {
116        if self.exists() && !self.is_server_alive() {
117            // Best-effort cleanup of stale socket files
118            #[allow(clippy::let_underscore_must_use)]
119            let _ = self.cleanup();
120            true
121        } else {
122            false
123        }
124    }
125
126    /// Remove socket and PID files (cleanup)
127    pub fn cleanup(&self) -> io::Result<()> {
128        if self.data.exists() {
129            std::fs::remove_file(&self.data)?;
130        }
131        if self.control.exists() {
132            std::fs::remove_file(&self.control)?;
133        }
134        if self.pid.exists() {
135            std::fs::remove_file(&self.pid)?;
136        }
137        Ok(())
138    }
139}
140
141/// Type alias for interprocess local socket stream
142type LocalStream = Stream;
143type LocalListener = Listener;
144
145/// Server listener for accepting client connections
146pub struct ServerListener {
147    data_listener: LocalListener,
148    control_listener: LocalListener,
149    paths: SocketPaths,
150}
151
152impl ServerListener {
153    /// Create a new server listener for the given socket paths
154    pub fn bind(paths: SocketPaths) -> io::Result<Self> {
155        tracing::debug!("ServerListener::bind starting for {:?}", paths.data);
156
157        // Clean up any stale sockets
158        paths.cleanup()?;
159
160        // Ensure socket directory exists
161        if let Some(parent) = paths.data.parent() {
162            tracing::debug!("Creating socket directory: {:?}", parent);
163            std::fs::create_dir_all(parent)?;
164        }
165
166        let data_name = platform::socket_name_for_path(&paths.data)?;
167        let control_name = platform::socket_name_for_path(&paths.control)?;
168
169        tracing::debug!("Creating data listener...");
170        let data_listener = ListenerOptions::new()
171            .name(data_name)
172            .create_sync()
173            .map_err(|e| {
174                tracing::error!("Failed to create data listener: {}", e);
175                io::Error::new(io::ErrorKind::AddrInUse, e.to_string())
176            })?;
177
178        tracing::debug!("Creating control listener...");
179        let control_listener = ListenerOptions::new()
180            .name(control_name)
181            .create_sync()
182            .map_err(|e| {
183                tracing::error!("Failed to create control listener: {}", e);
184                io::Error::new(io::ErrorKind::AddrInUse, e.to_string())
185            })?;
186
187        // Write marker files so exists() check works on Windows
188        // (Unix domain sockets already create socket files on the filesystem)
189        #[cfg(windows)]
190        {
191            tracing::debug!("Writing marker files...");
192            std::fs::write(&paths.data, "socket")?;
193            std::fs::write(&paths.control, "socket")?;
194        }
195
196        tracing::info!("Server listening on {:?}", paths.data);
197
198        Ok(Self {
199            data_listener,
200            control_listener,
201            paths,
202        })
203    }
204
205    /// Accept a new client connection (both data and control sockets)
206    /// Returns None if no connection is pending
207    pub fn accept(&mut self) -> io::Result<Option<ServerConnection>> {
208        // Try to accept on control socket first (client connects here first)
209        // Use set_nonblocking for non-blocking accept
210        self.control_listener
211            .set_nonblocking(ListenerNonblockingMode::Accept)?;
212
213        let control_stream = match self.control_listener.accept() {
214            Ok(stream) => stream,
215            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
216                return Ok(None);
217            }
218            #[cfg(windows)]
219            Err(e) if platform_windows::is_transient_pipe_error(&e) => {
220                return Ok(None);
221            }
222            Err(e) => return Err(e),
223        };
224
225        // Now wait for data socket connection (blocking)
226        self.data_listener
227            .set_nonblocking(ListenerNonblockingMode::Neither)?;
228        let data_stream = self.data_listener.accept()?;
229
230        // On Windows, DON'T set nonblocking here - the try_read() function handles it
231        // Setting nonblocking early can cause issues with named pipes where read()
232        // returns Ok(0) when no data is available (interpreted as EOF).
233        #[cfg(not(windows))]
234        {
235            // Set data stream to nonblocking for polling (Unix only)
236            #[allow(clippy::let_underscore_must_use)]
237            let _ = data_stream.set_nonblocking(true);
238            control_stream.set_nonblocking(true)?;
239        }
240
241        Ok(Some(ServerConnection {
242            data: StreamWrapper::new(data_stream),
243            control: StreamWrapper::new(control_stream),
244        }))
245    }
246
247    /// Get the socket paths
248    pub fn paths(&self) -> &SocketPaths {
249        &self.paths
250    }
251}
252
253impl Drop for ServerListener {
254    fn drop(&mut self) {
255        // Best-effort cleanup of socket files on shutdown
256        #[allow(clippy::let_underscore_must_use)]
257        let _ = self.paths.cleanup();
258    }
259}
260
261/// Wrapper for LocalSocketStream that provides thread-safe sharing
262/// Uses Arc<Mutex<>> internally to allow cloning and use across threads
263#[derive(Clone)]
264pub struct StreamWrapper(Arc<Mutex<LocalStream>>);
265
266impl StreamWrapper {
267    /// Create a new StreamWrapper from a LocalStream
268    fn new(stream: LocalStream) -> Self {
269        Self(Arc::new(Mutex::new(stream)))
270    }
271
272    /// Set non-blocking mode
273    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
274        self.0
275            .lock()
276            .map_err(|_| io::Error::other("mutex poisoned"))?
277            .set_nonblocking(nonblocking)
278    }
279
280    /// Write all bytes (takes &self for thread sharing)
281    pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
282        let mut guard = self
283            .0
284            .lock()
285            .map_err(|_| io::Error::other("mutex poisoned"))?;
286        Write::write_all(&mut *guard, buf)
287    }
288
289    /// Flush the stream
290    pub fn flush(&self) -> io::Result<()> {
291        let mut guard = self
292            .0
293            .lock()
294            .map_err(|_| io::Error::other("mutex poisoned"))?;
295        Write::flush(&mut *guard)
296    }
297
298    /// Try to read without blocking (returns WouldBlock if no data or if mutex is contended)
299    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
300        let mut guard = match self.0.try_lock() {
301            Ok(g) => g,
302            Err(std::sync::TryLockError::WouldBlock) => {
303                return Err(io::Error::new(
304                    io::ErrorKind::WouldBlock,
305                    "stream busy (mutex contended)",
306                ));
307            }
308            Err(std::sync::TryLockError::Poisoned(_)) => {
309                return Err(io::Error::other("mutex poisoned"));
310            }
311        };
312
313        platform::try_read_nonblocking(&mut guard, buf)
314    }
315}
316
/// Convert transient Windows named-pipe errors into `ErrorKind::WouldBlock`.
///
/// On non-Windows targets the result passes through unchanged.
#[inline]
fn map_windows_pipe_error(result: io::Result<usize>) -> io::Result<usize> {
    #[cfg(windows)]
    {
        if let Err(err) = result {
            return Err(if platform_windows::is_transient_pipe_error(&err) {
                io::Error::new(io::ErrorKind::WouldBlock, err)
            } else {
                err
            });
        }
    }
    result
}
328
329impl Read for StreamWrapper {
330    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
331        let result = self
332            .0
333            .lock()
334            .map_err(|_| io::Error::other("mutex poisoned"))?
335            .read(buf);
336        map_windows_pipe_error(result)
337    }
338}
339
340impl Read for &StreamWrapper {
341    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
342        let result = self
343            .0
344            .lock()
345            .map_err(|_| io::Error::other("mutex poisoned"))?
346            .read(buf);
347        map_windows_pipe_error(result)
348    }
349}
350
351/// A client connection (from the server's perspective)
352pub struct ServerConnection {
353    /// Data stream for raw byte stream
354    pub data: StreamWrapper,
355    /// Control stream for JSON messages
356    pub control: StreamWrapper,
357}
358
359impl ServerConnection {
360    /// Read available data from the data socket (non-blocking)
361    pub fn read_data(&self, buf: &mut [u8]) -> io::Result<usize> {
362        self.data.try_read(buf)
363    }
364
365    /// Write data to the data socket
366    pub fn write_data(&self, buf: &[u8]) -> io::Result<()> {
367        self.data.write_all(buf)?;
368        self.data.flush()
369    }
370
371    /// Read a control message (blocking)
372    pub fn read_control(&self) -> io::Result<Option<String>> {
373        // On Windows, don't toggle blocking mode - named pipes don't support mode
374        // switching after connection. The pipe should already be in blocking mode.
375        #[cfg(not(windows))]
376        self.control.set_nonblocking(false)?;
377        let mut reader = BufReader::new(&self.control);
378        let mut line = String::new();
379        match reader.read_line(&mut line) {
380            Ok(0) => Ok(None), // EOF
381            Ok(_) => Ok(Some(line)),
382            Err(e) => Err(e),
383        }
384    }
385
386    /// Write a control message
387    pub fn write_control(&self, msg: &str) -> io::Result<()> {
388        self.control.write_all(msg.as_bytes())?;
389        if !msg.ends_with('\n') {
390            self.control.write_all(b"\n")?;
391        }
392        self.control.flush()
393    }
394}
395
396/// Client connection to server
397pub struct ClientConnection {
398    /// Data stream for raw byte stream
399    pub data: StreamWrapper,
400    /// Control stream for JSON messages
401    pub control: StreamWrapper,
402}
403
404impl ClientConnection {
405    /// Connect to a server at the given socket paths
406    pub fn connect(paths: &SocketPaths) -> io::Result<Self> {
407        let control_name = platform::socket_name_for_path(&paths.control)?;
408        let data_name = platform::socket_name_for_path(&paths.data)?;
409
410        // Connect control socket first, then data (matching server's accept order)
411        let control = Stream::connect(control_name)
412            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
413
414        let data = Stream::connect(data_name)
415            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
416
417        Ok(Self {
418            data: StreamWrapper::new(data),
419            control: StreamWrapper::new(control),
420        })
421    }
422
423    /// Set data socket to non-blocking mode
424    pub fn set_data_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
425        self.data.set_nonblocking(nonblocking)
426    }
427
428    /// Read from data socket
429    pub fn read_data(&self, buf: &mut [u8]) -> io::Result<usize> {
430        self.data.try_read(buf)
431    }
432
433    /// Write to data socket
434    pub fn write_data(&self, buf: &[u8]) -> io::Result<()> {
435        self.data.write_all(buf)?;
436        self.data.flush()
437    }
438
439    /// Read a control message
440    pub fn read_control(&self) -> io::Result<Option<String>> {
441        let mut reader = BufReader::new(&self.control);
442        let mut line = String::new();
443        match reader.read_line(&mut line) {
444            Ok(0) => Ok(None),
445            Ok(_) => Ok(Some(line)),
446            Err(e) => Err(e),
447        }
448    }
449
450    /// Write a control message
451    pub fn write_control(&self, msg: &str) -> io::Result<()> {
452        self.control.write_all(msg.as_bytes())?;
453        if !msg.ends_with('\n') {
454            self.control.write_all(b"\n")?;
455        }
456        self.control.flush()
457    }
458
459    /// Get the raw file descriptors for use with poll/select (Unix only)
460    #[cfg(unix)]
461    pub fn as_raw_fds(&self) -> (std::os::unix::io::RawFd, std::os::unix::io::RawFd) {
462        use std::os::unix::io::{AsFd, AsRawFd};
463        let data_guard = self.data.0.lock().unwrap();
464        let ctrl_guard = self.control.0.lock().unwrap();
465        let data_fd = match &*data_guard {
466            Stream::UdSocket(s) => s.as_fd().as_raw_fd(),
467        };
468        let ctrl_fd = match &*ctrl_guard {
469            Stream::UdSocket(s) => s.as_fd().as_raw_fd(),
470        };
471        (data_fd, ctrl_fd)
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_socket_paths_encode_working_dir() {
        let paths = SocketPaths::for_working_dir(Path::new("/home/user/project")).unwrap();
        // Should encode path separators
        assert!(paths.data.to_string_lossy().contains("home_user_project"));
        assert!(paths.data.to_string_lossy().ends_with(".data.sock"));
        assert!(paths.control.to_string_lossy().ends_with(".ctrl.sock"));
    }

    #[test]
    fn test_named_session_uses_name_directly() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("my-session", temp_dir.path());
        assert!(paths
            .data
            .to_string_lossy()
            .contains("my-session.data.sock"));
        assert!(paths
            .control
            .to_string_lossy()
            .contains("my-session.ctrl.sock"));
    }

    #[test]
    fn test_exists_returns_false_for_missing_sockets() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        assert!(!paths.exists());
    }

    #[test]
    fn test_exists_requires_both_sockets() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        std::fs::write(&paths.data, "").unwrap();
        assert!(!paths.exists());
        std::fs::write(&paths.control, "").unwrap();
        assert!(paths.exists());
    }

    #[test]
    fn test_cleanup_succeeds_on_missing_files() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        // Should not error when files don't exist
        assert!(paths.cleanup().is_ok());
    }

    #[test]
    fn test_cleanup_removes_existing_files() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        std::fs::write(&paths.data, "").unwrap();
        std::fs::write(&paths.control, "").unwrap();
        paths.write_pid(42).unwrap();

        paths.cleanup().unwrap();

        assert!(!paths.data.exists());
        assert!(!paths.control.exists());
        assert!(!paths.pid.exists());
    }

    #[test]
    fn test_cleanup_removes_partial_leftovers() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        // Only the control socket and PID file are left behind
        std::fs::write(&paths.control, "").unwrap();
        paths.write_pid(7).unwrap();

        paths.cleanup().unwrap();

        assert!(!paths.control.exists());
        assert!(!paths.pid.exists());
    }

    #[test]
    fn test_socket_directory_creates_dir() {
        let dir = SocketPaths::socket_directory().unwrap();
        assert!(dir.exists());
        assert!(dir.is_dir());
    }

    #[test]
    fn test_different_working_dirs_get_different_paths() {
        let paths1 = SocketPaths::for_working_dir(Path::new("/home/user/project1")).unwrap();
        let paths2 = SocketPaths::for_working_dir(Path::new("/home/user/project2")).unwrap();
        assert_ne!(paths1.data, paths2.data);
        assert_ne!(paths1.control, paths2.control);
    }

    #[test]
    fn test_same_working_dir_gets_same_paths() {
        let paths1 = SocketPaths::for_working_dir(Path::new("/home/user/project")).unwrap();
        let paths2 = SocketPaths::for_working_dir(Path::new("/home/user/project")).unwrap();
        assert_eq!(paths1.data, paths2.data);
        assert_eq!(paths1.control, paths2.control);
    }

    #[test]
    fn test_pid_file_path_included() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("pid-test-session", temp_dir.path());
        assert!(paths.pid.to_string_lossy().contains("pid-test-session.pid"));
    }

    #[test]
    fn test_write_and_read_pid() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());

        // Write PID
        paths.write_pid(12345).unwrap();
        assert!(paths.pid.exists());

        // Read PID
        let pid = paths.read_pid().unwrap();
        assert_eq!(pid, Some(12345));

        // Clean up
        paths.cleanup().unwrap();
        assert!(!paths.pid.exists());
    }

    #[test]
    fn test_read_pid_returns_none_for_missing_file() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());
        assert_eq!(paths.read_pid().unwrap(), None);
    }

    #[test]
    fn test_read_pid_trims_whitespace_and_rejects_garbage() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());

        std::fs::write(&paths.pid, " 4242\n").unwrap();
        assert_eq!(paths.read_pid().unwrap(), Some(4242));

        std::fs::write(&paths.pid, "not-a-pid").unwrap();
        assert_eq!(paths.read_pid().unwrap(), None);
    }

    #[test]
    fn test_cleanup_if_stale_with_no_sockets() {
        let temp_dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("test-session", temp_dir.path());

        // No sockets exist, should return false (nothing to clean)
        assert!(!paths.cleanup_if_stale());
    }
}