use std::io::{self, BufRead, BufReader, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use interprocess::local_socket::{
prelude::*, Listener, ListenerNonblockingMode, ListenerOptions, Stream,
};
use crate::workspace::encode_path_for_filename;
#[cfg(unix)]
mod platform_unix;
#[cfg(windows)]
mod platform_windows;
#[cfg(unix)]
use platform_unix as platform;
#[cfg(windows)]
use platform_windows as platform;
/// Filesystem locations used by a single server session.
///
/// All three paths share one file stem, which is either a caller-supplied
/// session name ([`SocketPaths::for_session_name`]) or an encoding of a
/// working directory ([`SocketPaths::for_working_dir`]).
#[derive(Debug, Clone)]
pub struct SocketPaths {
    /// Data-channel socket path: `<stem>.data.sock`.
    pub data: PathBuf,
    /// Control-channel socket path: `<stem>.ctrl.sock`.
    pub control: PathBuf,
    /// File holding the server's PID as decimal text: `<stem>.pid`.
    pub pid: PathBuf,
}

impl SocketPaths {
    /// Returns the platform-specific directory that holds session sockets.
    ///
    /// # Errors
    /// Propagates any error from the platform lookup.
    pub fn socket_directory() -> io::Result<PathBuf> {
        platform::get_socket_dir()
    }

    /// Paths for the session tied to `working_dir`.
    ///
    /// The directory is turned into a file stem with
    /// `encode_path_for_filename`, so the same directory maps to the same
    /// paths every time.
    ///
    /// # Errors
    /// Propagates any error from locating the socket directory.
    pub fn for_working_dir(working_dir: &Path) -> io::Result<Self> {
        let socket_dir = platform::get_socket_dir()?;
        let encoded = encode_path_for_filename(working_dir);
        Ok(Self::from_stem(encoded, &socket_dir))
    }

    /// Paths for the session called `name` in the platform socket directory.
    ///
    /// # Errors
    /// Propagates any error from locating the socket directory.
    pub fn for_session_name(name: &str) -> io::Result<Self> {
        let socket_dir = platform::get_socket_dir()?;
        Ok(Self::for_session_name_in_dir(name, &socket_dir))
    }

    /// Paths for the session called `name` inside `socket_dir`.
    ///
    /// `name` is used verbatim as the file stem; it is not encoded or
    /// validated.
    pub fn for_session_name_in_dir(name: &str, socket_dir: &Path) -> Self {
        Self::from_stem(name, socket_dir)
    }

    /// Shared constructor: joins `<stem>.data.sock`, `<stem>.ctrl.sock` and
    /// `<stem>.pid` onto `socket_dir`.
    fn from_stem(stem: impl std::fmt::Display, socket_dir: &Path) -> Self {
        Self {
            data: socket_dir.join(format!("{stem}.data.sock")),
            control: socket_dir.join(format!("{stem}.ctrl.sock")),
            pid: socket_dir.join(format!("{stem}.pid")),
        }
    }

    /// `true` when both the data and the control socket paths exist on disk.
    pub fn exists(&self) -> bool {
        self.data.exists() && self.control.exists()
    }

    /// Writes `pid` as decimal text to the PID file, replacing its contents.
    ///
    /// # Errors
    /// Propagates the underlying write error.
    pub fn write_pid(&self, pid: u32) -> io::Result<()> {
        std::fs::write(&self.pid, pid.to_string())
    }

    /// Reads the PID file.
    ///
    /// Returns `Ok(None)` when the file does not exist or when its trimmed
    /// contents do not parse as a `u32`.
    ///
    /// # Errors
    /// Any read error other than `NotFound`.
    pub fn read_pid(&self) -> io::Result<Option<u32>> {
        // Read directly and treat NotFound as "no PID". An `exists()` check
        // followed by a read would race with a concurrent `cleanup`.
        match std::fs::read_to_string(&self.pid) {
            Ok(content) => Ok(content.trim().parse().ok()),
            Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(e),
        }
    }

    /// Best-effort check for whether a server is serving these paths.
    ///
    /// Returns `true` if the PID file names a process that
    /// `is_process_running` reports as running. Otherwise, if both socket
    /// paths exist, it defers to `platform::check_server_by_connect` on the
    /// control socket.
    ///
    /// NOTE(review): a recycled PID can make a dead server look alive.
    pub fn is_server_alive(&self) -> bool {
        use crate::server::daemon::is_process_running;
        if matches!(self.read_pid(), Ok(Some(pid)) if is_process_running(pid)) {
            return true;
        }
        self.exists() && platform::check_server_by_connect(&self.control)
    }

    /// Removes the session files if the sockets exist but no live server is
    /// detected.
    ///
    /// Returns `true` when cleanup was attempted. Errors from the cleanup
    /// itself are ignored.
    pub fn cleanup_if_stale(&self) -> bool {
        let stale = self.exists() && !self.is_server_alive();
        if stale {
            // Best-effort: callers only need to know whether it was stale.
            #[allow(clippy::let_underscore_must_use)]
            let _ = self.cleanup();
        }
        stale
    }

    /// Deletes the data socket, the control socket and the PID file.
    ///
    /// Missing files are not an error. Every file is attempted even if an
    /// earlier removal fails, so one stuck file does not strand the others.
    ///
    /// # Errors
    /// The first removal error other than `NotFound`.
    pub fn cleanup(&self) -> io::Result<()> {
        let mut outcome = Ok(());
        for path in [&self.data, &self.control, &self.pid] {
            let removed = Self::remove_if_present(path);
            if outcome.is_ok() {
                outcome = removed;
            }
        }
        outcome
    }

    /// Removes `path` and treats "already gone" as success. Removing
    /// directly avoids an `exists()`-then-remove race.
    fn remove_if_present(path: &Path) -> io::Result<()> {
        match std::fs::remove_file(path) {
            Err(e) if e.kind() != io::ErrorKind::NotFound => Err(e),
            _ => Ok(()),
        }
    }
}
/// Module-local alias for the `interprocess` local-socket stream type.
type LocalStream = Stream;
/// Module-local alias for the `interprocess` local-socket listener type.
type LocalListener = Listener;
/// Server-side pair of local-socket listeners (data + control) for one session.
///
/// Created by [`ServerListener::bind`]. Its `Drop` impl calls
/// [`SocketPaths::cleanup`] on the stored paths.
pub struct ServerListener {
    /// Listener for the data channel.
    data_listener: LocalListener,
    /// Listener for the control channel. `ClientConnection::connect` dials
    /// this socket before the data socket.
    control_listener: LocalListener,
    /// The paths this listener was bound to.
    paths: SocketPaths,
}
impl ServerListener {
    /// Binds the data and control listeners at `paths`.
    ///
    /// In order, it:
    /// 1. removes any existing files at `paths` via [`SocketPaths::cleanup`];
    /// 2. creates the parent directory of the data path;
    /// 3. creates the data listener, then the control listener;
    /// 4. on Windows only, writes `"socket"` marker files at the data and
    ///    control paths.
    ///
    /// NOTE(review): step 1 runs unconditionally. Binding while another
    /// server is live therefore replaces that server's files. Presumably
    /// callers check [`SocketPaths::is_server_alive`] first; confirm.
    ///
    /// # Errors
    /// Propagates cleanup, directory-creation, socket-name and (on Windows)
    /// marker-write errors. A failure to create either listener is reported
    /// as [`io::ErrorKind::AddrInUse`] whatever the underlying cause; only
    /// the original message text is kept.
    pub fn bind(paths: SocketPaths) -> io::Result<Self> {
        tracing::debug!("ServerListener::bind starting for {:?}", paths.data);
        paths.cleanup()?;
        if let Some(parent) = paths.data.parent() {
            tracing::debug!("Creating socket directory: {:?}", parent);
            std::fs::create_dir_all(parent)?;
        }
        let data_name = platform::socket_name_for_path(&paths.data)?;
        let control_name = platform::socket_name_for_path(&paths.control)?;
        tracing::debug!("Creating data listener...");
        let data_listener = ListenerOptions::new()
            .name(data_name)
            .create_sync()
            .map_err(|e| {
                tracing::error!("Failed to create data listener: {}", e);
                io::Error::new(io::ErrorKind::AddrInUse, e.to_string())
            })?;
        tracing::debug!("Creating control listener...");
        let control_listener = ListenerOptions::new()
            .name(control_name)
            .create_sync()
            .map_err(|e| {
                tracing::error!("Failed to create control listener: {}", e);
                io::Error::new(io::ErrorKind::AddrInUse, e.to_string())
            })?;
        #[cfg(windows)]
        {
            // Presumably the Windows socket names are not files at these
            // paths, so marker files are written to let
            // `SocketPaths::exists` (which checks both paths) see the
            // session. Confirm.
            tracing::debug!("Writing marker files...");
            std::fs::write(&paths.data, "socket")?;
            std::fs::write(&paths.control, "socket")?;
        }
        tracing::info!("Server listening on {:?}", paths.data);
        Ok(Self {
            data_listener,
            control_listener,
            paths,
        })
    }
    /// Polls for one new client connection.
    ///
    /// Returns `Ok(None)` immediately when no control connection is pending.
    /// That covers `WouldBlock` and, on Windows, errors that
    /// `platform_windows::is_transient_pipe_error` accepts. Once a control
    /// connection is accepted, the data listener is switched to blocking
    /// mode and this call waits for the matching data connection.
    ///
    /// NOTE(review): if a client connects the control socket but never the
    /// data socket, this call blocks indefinitely.
    ///
    /// # Errors
    /// Any other accept error, or a failure to change nonblocking mode.
    pub fn accept(&mut self) -> io::Result<Option<ServerConnection>> {
        // Poll the control listener without waiting.
        self.control_listener
            .set_nonblocking(ListenerNonblockingMode::Accept)?;
        let control_stream = match self.control_listener.accept() {
            Ok(stream) => stream,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                return Ok(None);
            }
            #[cfg(windows)]
            Err(e) if platform_windows::is_transient_pipe_error(&e) => {
                return Ok(None);
            }
            Err(e) => return Err(e),
        };
        // A control connection arrived; wait (blocking) for its data pair.
        self.data_listener
            .set_nonblocking(ListenerNonblockingMode::Neither)?;
        let data_stream = self.data_listener.accept()?;
        #[cfg(not(windows))]
        {
            // Best-effort for the data stream (error ignored); the control
            // stream's mode change is required (error propagated).
            #[allow(clippy::let_underscore_must_use)]
            let _ = data_stream.set_nonblocking(true);
            control_stream.set_nonblocking(true)?;
        }
        Ok(Some(ServerConnection {
            data: StreamWrapper::new(data_stream),
            control: StreamWrapper::new(control_stream),
        }))
    }
    /// The paths this listener was bound to.
    pub fn paths(&self) -> &SocketPaths {
        &self.paths
    }
}
impl Drop for ServerListener {
fn drop(&mut self) {
#[allow(clippy::let_underscore_must_use)]
let _ = self.paths.cleanup();
}
}
/// Cloneable handle to a local-socket stream.
///
/// Clones share one underlying stream through the `Arc`, and every
/// operation takes the inner `Mutex`.
#[derive(Clone)]
pub struct StreamWrapper(Arc<Mutex<LocalStream>>);
impl StreamWrapper {
    /// Wraps `stream` so it can be shared between clones of the handle.
    fn new(stream: LocalStream) -> Self {
        let shared = Mutex::new(stream);
        Self(Arc::new(shared))
    }

    /// Takes the stream lock and turns a poisoned mutex into an `io::Error`.
    fn locked(&self) -> io::Result<std::sync::MutexGuard<'_, LocalStream>> {
        self.0.lock().map_err(|_| io::Error::other("mutex poisoned"))
    }

    /// Switches the underlying stream between blocking and nonblocking mode.
    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        self.locked()?.set_nonblocking(nonblocking)
    }

    /// Writes all of `buf` while holding the lock for the whole write.
    pub fn write_all(&self, buf: &[u8]) -> io::Result<()> {
        let mut stream = self.locked()?;
        Write::write_all(&mut *stream, buf)
    }

    /// Flushes the underlying stream.
    pub fn flush(&self) -> io::Result<()> {
        let mut stream = self.locked()?;
        Write::flush(&mut *stream)
    }

    /// Attempts a read without waiting for the lock.
    ///
    /// If another holder has the mutex, this returns a `WouldBlock` error
    /// instead of waiting. Otherwise it defers to
    /// `platform::try_read_nonblocking`.
    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
        use std::sync::TryLockError;
        let mut guard = self.0.try_lock().map_err(|failure| match failure {
            TryLockError::WouldBlock => io::Error::new(
                io::ErrorKind::WouldBlock,
                "stream busy (mutex contended)",
            ),
            TryLockError::Poisoned(_) => io::Error::other("mutex poisoned"),
        })?;
        platform::try_read_nonblocking(&mut guard, buf)
    }
}
/// Normalises a read result for callers that poll.
///
/// On Windows, errors that `platform_windows::is_transient_pipe_error`
/// accepts are rewrapped as `WouldBlock`. Every other result, and every
/// result on other platforms, passes through unchanged.
#[inline]
fn map_windows_pipe_error(result: io::Result<usize>) -> io::Result<usize> {
    #[cfg(windows)]
    {
        result.map_err(|err| {
            if platform_windows::is_transient_pipe_error(&err) {
                io::Error::new(io::ErrorKind::WouldBlock, err)
            } else {
                err
            }
        })
    }
    #[cfg(not(windows))]
    {
        result
    }
}
impl Read for StreamWrapper {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let result = self
.0
.lock()
.map_err(|_| io::Error::other("mutex poisoned"))?
.read(buf);
map_windows_pipe_error(result)
}
}
impl Read for &StreamWrapper {
    /// Reads into `buf` through a shared reference to the wrapper.
    ///
    /// The lock is held only for the read itself. A poisoned mutex becomes
    /// an `io::Error`, and the result then goes through
    /// `map_windows_pipe_error`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let outcome = {
            let mut stream = self
                .0
                .lock()
                .map_err(|_| io::Error::other("mutex poisoned"))?;
            stream.read(buf)
        };
        map_windows_pipe_error(outcome)
    }
}
/// Server side of an accepted session: the paired data and control streams
/// produced by [`ServerListener::accept`].
pub struct ServerConnection {
    /// Data-channel stream.
    pub data: StreamWrapper,
    /// Control-channel stream carrying newline-terminated text messages.
    pub control: StreamWrapper,
}
impl ServerConnection {
    /// Reads available data without waiting on a contended stream lock.
    ///
    /// Delegates to [`StreamWrapper::try_read`]. A `WouldBlock` error means
    /// the lock was held elsewhere. Presumably it can also mean that no data
    /// is ready, depending on `platform::try_read_nonblocking`.
    pub fn read_data(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.data.try_read(buf)
    }

    /// Writes all of `buf` to the data stream and flushes it.
    pub fn write_data(&self, buf: &[u8]) -> io::Result<()> {
        self.data.write_all(buf)?;
        self.data.flush()
    }

    /// Reads one control message, including its trailing `\n` if present.
    ///
    /// Returns `Ok(None)` on end of stream. On non-Windows platforms the
    /// control stream is switched to blocking mode first and is left in that
    /// mode.
    ///
    /// # Errors
    /// Read errors, including `InvalidData` for non-UTF-8 input.
    pub fn read_control(&self) -> io::Result<Option<String>> {
        #[cfg(not(windows))]
        self.control.set_nonblocking(false)?;
        // Capacity 1 means the reader never pulls bytes past the newline.
        // A default (8 KiB) BufReader could read following messages off the
        // socket and then drop them along with the reader.
        let mut reader = BufReader::with_capacity(1, &self.control);
        let mut line = String::new();
        let bytes_read = reader.read_line(&mut line)?;
        Ok((bytes_read > 0).then_some(line))
    }

    /// Sends `msg` on the control stream, adding a `\n` terminator if it is
    /// missing, then flushes.
    pub fn write_control(&self, msg: &str) -> io::Result<()> {
        if msg.ends_with('\n') {
            self.control.write_all(msg.as_bytes())?;
        } else {
            // Frame the message before writing so the message and its
            // newline go out in one locked `write_all`. Two separate writes
            // could interleave with writes from another clone of the stream.
            let mut framed = String::with_capacity(msg.len() + 1);
            framed.push_str(msg);
            framed.push('\n');
            self.control.write_all(framed.as_bytes())?;
        }
        self.control.flush()
    }
}
/// Client side of a session: connected data and control streams.
pub struct ClientConnection {
    /// Data-channel stream.
    pub data: StreamWrapper,
    /// Control-channel stream carrying newline-terminated text messages.
    pub control: StreamWrapper,
}
impl ClientConnection {
    /// Connects to the server described by `paths`.
    ///
    /// The control socket is connected before the data socket, which is the
    /// order `ServerListener::accept` accepts them in.
    ///
    /// # Errors
    /// Socket-name resolution errors are propagated. A failed connect on
    /// either socket is reported as `ConnectionRefused`, keeping the
    /// original message text.
    pub fn connect(paths: &SocketPaths) -> io::Result<Self> {
        let control_name = platform::socket_name_for_path(&paths.control)?;
        let data_name = platform::socket_name_for_path(&paths.data)?;
        let control = Stream::connect(control_name)
            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
        let data = Stream::connect(data_name)
            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e.to_string()))?;
        Ok(Self {
            data: StreamWrapper::new(data),
            control: StreamWrapper::new(control),
        })
    }

    /// Switches the data stream between blocking and nonblocking mode.
    pub fn set_data_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        self.data.set_nonblocking(nonblocking)
    }

    /// Reads available data without waiting on a contended stream lock.
    ///
    /// Delegates to [`StreamWrapper::try_read`].
    pub fn read_data(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.data.try_read(buf)
    }

    /// Writes all of `buf` to the data stream and flushes it.
    pub fn write_data(&self, buf: &[u8]) -> io::Result<()> {
        self.data.write_all(buf)?;
        self.data.flush()
    }

    /// Reads one control message, including its trailing `\n` if present.
    ///
    /// Returns `Ok(None)` on end of stream.
    ///
    /// # Errors
    /// Read errors, including `InvalidData` for non-UTF-8 input.
    pub fn read_control(&self) -> io::Result<Option<String>> {
        // Capacity 1 means the reader never pulls bytes past the newline.
        // A default (8 KiB) BufReader could read following messages off the
        // socket and then drop them along with the reader.
        let mut reader = BufReader::with_capacity(1, &self.control);
        let mut line = String::new();
        let bytes_read = reader.read_line(&mut line)?;
        Ok((bytes_read > 0).then_some(line))
    }

    /// Sends `msg` on the control stream, adding a `\n` terminator if it is
    /// missing, then flushes.
    pub fn write_control(&self, msg: &str) -> io::Result<()> {
        if msg.ends_with('\n') {
            self.control.write_all(msg.as_bytes())?;
        } else {
            // Frame the message before writing so the message and its
            // newline go out in one locked `write_all`. Two separate writes
            // could interleave with writes from another clone of the stream.
            let mut framed = String::with_capacity(msg.len() + 1);
            framed.push_str(msg);
            framed.push('\n');
            self.control.write_all(framed.as_bytes())?;
        }
        self.control.flush()
    }

    /// Raw file descriptors of the `(data, control)` sockets.
    ///
    /// The descriptors remain owned by this connection. They are valid only
    /// while it is alive and must not be closed by the caller.
    #[cfg(unix)]
    pub fn as_raw_fds(&self) -> (std::os::unix::io::RawFd, std::os::unix::io::RawFd) {
        /// Reads the raw fd of one wrapped stream. A poisoned lock is
        /// recovered instead of panicking, because only the fd number is
        /// read and no shared state is modified.
        fn raw_fd_of(wrapper: &StreamWrapper) -> std::os::unix::io::RawFd {
            use std::os::unix::io::{AsFd, AsRawFd};
            let guard = wrapper
                .0
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            match &*guard {
                Stream::UdSocket(s) => s.as_fd().as_raw_fd(),
            }
        }
        // Each lock is taken and released in turn; both are never held at once.
        (raw_fd_of(&self.data), raw_fd_of(&self.control))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a fresh temp dir and the `SocketPaths` for `name` inside it.
    /// The `TempDir` guard is returned so the directory lives as long as the
    /// test holds it.
    fn session_in_temp(name: &str) -> (TempDir, SocketPaths) {
        let dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir(name, dir.path());
        (dir, paths)
    }

    /// `SocketPaths` for a working directory given as a string.
    fn paths_for(working_dir: &str) -> SocketPaths {
        SocketPaths::for_working_dir(Path::new(working_dir)).unwrap()
    }

    #[test]
    fn test_socket_paths_encode_working_dir() {
        let paths = paths_for("/home/user/project");
        let data = paths.data.to_string_lossy();
        assert!(data.contains("home_user_project"));
        assert!(data.ends_with(".data.sock"));
        assert!(paths.control.to_string_lossy().ends_with(".ctrl.sock"));
    }

    #[test]
    fn test_named_session_uses_name_directly() {
        let (_dir, paths) = session_in_temp("my-session");
        assert!(paths.data.to_string_lossy().contains("my-session.data.sock"));
        assert!(paths.control.to_string_lossy().contains("my-session.ctrl.sock"));
    }

    #[test]
    fn test_exists_returns_false_for_missing_sockets() {
        let (_dir, paths) = session_in_temp("test-session");
        assert!(!paths.exists());
    }

    #[test]
    fn test_cleanup_succeeds_on_missing_files() {
        let (_dir, paths) = session_in_temp("test-session");
        assert!(paths.cleanup().is_ok());
    }

    #[test]
    fn test_socket_directory_creates_dir() {
        let socket_dir = SocketPaths::socket_directory().unwrap();
        assert!(socket_dir.exists());
        assert!(socket_dir.is_dir());
    }

    #[test]
    fn test_different_working_dirs_get_different_paths() {
        let first = paths_for("/home/user/project1");
        let second = paths_for("/home/user/project2");
        assert_ne!(first.data, second.data);
        assert_ne!(first.control, second.control);
    }

    #[test]
    fn test_same_working_dir_gets_same_paths() {
        let first = paths_for("/home/user/project");
        let second = paths_for("/home/user/project");
        assert_eq!(first.data, second.data);
        assert_eq!(first.control, second.control);
    }

    #[test]
    fn test_pid_file_path_included() {
        let (_dir, paths) = session_in_temp("pid-test-session");
        assert!(paths.pid.to_string_lossy().contains("pid-test-session.pid"));
    }

    #[test]
    fn test_write_and_read_pid() {
        let (_dir, paths) = session_in_temp("test-session");
        paths.write_pid(12345).unwrap();
        assert!(paths.pid.exists());
        assert_eq!(paths.read_pid().unwrap(), Some(12345));
        paths.cleanup().unwrap();
        assert!(!paths.pid.exists());
    }

    #[test]
    fn test_read_pid_returns_none_for_missing_file() {
        let (_dir, paths) = session_in_temp("test-session");
        assert_eq!(paths.read_pid().unwrap(), None);
    }

    #[test]
    fn test_cleanup_if_stale_with_no_sockets() {
        let (_dir, paths) = session_in_temp("test-session");
        assert!(!paths.cleanup_if_stale());
    }
}