#![cfg(unix)]
#![allow(unsafe_code)]
use std::fs;
use std::io;
use std::os::unix::fs::PermissionsExt;
use std::path::{Path, PathBuf};
use std::process;
use std::sync::mpsc::Sender;
use std::thread;
use rmcp::ServiceExt;
use crate::mcp_bridge::{McpBridgeServer, McpCommand};
/// Permission bits enforced on the default sessions directory (owner-only rwx).
const DIR_MODE: u32 = 0o700;
/// Permission bits applied to each bound socket file (owner-only rw).
const SOCK_MODE: u32 = 0o600;
pub fn default_socket_path() -> Option<PathBuf> {
let base = state_dir()?;
Some(
base.join("travelagent")
.join("sessions")
.join(format!("{}.sock", process::id())),
)
}
/// Directory in which default per-process sockets (`<pid>.sock`) are created:
/// `<state dir>/travelagent/sessions`.
///
/// Returns `None` when no state directory can be resolved from the environment.
pub fn sessions_dir() -> Option<PathBuf> {
    let mut dir = state_dir()?;
    dir.push("travelagent");
    dir.push("sessions");
    Some(dir)
}
/// Resolves the base directory for per-user persistent state, following the
/// XDG Base Directory specification.
///
/// Resolution order:
/// 1. `$XDG_STATE_HOME` when it is set to an absolute path.
/// 2. Otherwise `$HOME/.local/state` when `HOME` is set and non-empty.
/// 3. Otherwise `None`.
///
/// The spec declares relative values of `XDG_STATE_HOME` invalid and says they
/// must be ignored. An empty value is not absolute either, so it is skipped as well.
fn state_dir() -> Option<PathBuf> {
    std::env::var_os("XDG_STATE_HOME")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .or_else(|| {
            let home = std::env::var_os("HOME").filter(|home| !home.is_empty())?;
            Some(PathBuf::from(home).join(".local").join("state"))
        })
}
/// Chooses the socket path to bind.
///
/// An explicit value is used as given, except that `~` on its own or a `~/`
/// prefix is expanded against a non-empty `HOME`. Relative paths are not made
/// absolute. Without an explicit value, falls back to [`default_socket_path`].
///
/// # Errors
/// Returns `ErrorKind::NotFound` when no explicit path is given and no default
/// socket path can be resolved.
pub fn resolve_socket_path(explicit: Option<&str>) -> io::Result<PathBuf> {
    if let Some(raw) = explicit {
        return Ok(expand_tilde(raw));
    }
    default_socket_path().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::NotFound,
            "could not resolve default socket path: neither XDG_STATE_HOME nor HOME is set",
        )
    })
}
/// Expands a leading tilde in `raw` using the `HOME` environment variable.
///
/// - `~/rest` becomes `$HOME/rest`.
/// - `~` alone becomes `$HOME`.
/// - Anything else (including `~user/...`) is returned unchanged.
///
/// If `HOME` is unset or empty, `raw` is returned unchanged. `HOME` is only
/// consulted when `raw` actually starts with a tilde form.
fn expand_tilde(raw: &str) -> PathBuf {
    let home_dir = || {
        std::env::var_os("HOME")
            .filter(|home| !home.is_empty())
            .map(PathBuf::from)
    };
    let expanded = match raw.strip_prefix("~/") {
        Some(rest) => home_dir().map(|home| home.join(rest)),
        None if raw == "~" => home_dir(),
        None => None,
    };
    expanded.unwrap_or_else(|| PathBuf::from(raw))
}
/// Best-effort cleanup of session sockets left behind by processes that have exited.
///
/// Looks at every `<pid>.sock` entry directly inside `dir`, where `<pid>` is a
/// positive decimal process id. Each such entry is removed when `pid_is_alive`
/// reports that no such process exists.
///
/// Entries with any other name (e.g. `manual.sock`, `0.sock`, `-5.sock`) are left
/// untouched. A missing or unreadable directory, unreadable entries and failed
/// removals are all silently ignored.
///
/// NOTE(review): if an unrelated process recycles a PID, that PID's stale socket
/// looks alive. It is then kept until that process exits too.
pub fn sweep_stale_sockets(dir: &Path) {
    let Ok(entries) = fs::read_dir(dir) else {
        return;
    };
    for path in entries.flatten().map(|entry| entry.path()) {
        if path.extension().and_then(|ext| ext.to_str()) != Some("sock") {
            continue;
        }
        let Some(pid) = path
            .file_stem()
            .and_then(|stem| stem.to_str())
            .and_then(parse_session_pid)
        else {
            continue;
        };
        if !pid_is_alive(pid) {
            let _ = fs::remove_file(&path);
        }
    }
}

/// Parses a socket file stem as a PID that names exactly one process for `kill(2)`.
///
/// Accepts only plain ASCII digits (no sign) whose value is in `1..=i32::MAX`.
/// `kill` interprets `0` as "the caller's process group", `-1` as "every process"
/// and other negatives as process groups. Probing those values says nothing about
/// a single owning process, so they are rejected.
fn parse_session_pid(stem: &str) -> Option<i32> {
    if stem.is_empty() || !stem.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    stem.parse::<i32>().ok().filter(|&pid| pid > 0)
}
fn pid_is_alive(pid: i32) -> bool {
let ret = unsafe { libc::kill(pid, 0) };
if ret == 0 {
return true;
}
let errno = io::Error::last_os_error().raw_os_error().unwrap_or(0);
errno == libc::EPERM
}
/// Owns the filesystem entry of a bound Unix socket and unlinks it when dropped.
///
/// Returned by [`spawn_mcp_socket_server`]. Keep it alive for as long as the
/// socket should stay reachable at its path. Dropping it only removes the path; it
/// does not stop the accept loop serving the already-bound listener.
///
/// NOTE(review): the path is removed unconditionally. If another process has since
/// replaced the socket at the same path, that socket is unlinked instead.
#[derive(Debug)]
#[must_use = "dropping the SocketGuard immediately unlinks the socket file"]
pub struct SocketGuard {
    path: PathBuf,
}

impl SocketGuard {
    /// Path of the socket file this guard removes on drop.
    // `allow(dead_code)` keeps builds warning-free when nothing calls this accessor.
    // NOTE(review): confirm whether any caller uses it; if so, drop the allow.
    #[allow(dead_code)]
    pub fn path(&self) -> &Path {
        &self.path
    }
}

impl Drop for SocketGuard {
    fn drop(&mut self) {
        // Best-effort: the file may already be gone, and there is nothing useful
        // to do with an error during drop.
        let _ = fs::remove_file(&self.path);
    }
}
/// Binds a Unix-domain socket at `path` and serves the MCP bridge protocol on it.
///
/// Setup, all done before returning:
/// - Creates `path`'s parent directories. When the parent is exactly
///   [`sessions_dir`], it is tightened to `DIR_MODE` (0700). A user-supplied parent
///   keeps whatever permissions it already has.
/// - Removes a leftover *socket* at `path` (e.g. from a crashed run). Anything at
///   `path` that is not a socket, including a symlink, is never deleted; an
///   `AlreadyExists` error is returned instead.
/// - Binds the listener, sets the socket file to `SOCK_MODE` (0600) and registers
///   the listener with `runtime_handle`'s reactor.
///
/// A dedicated OS thread then runs the accept loop via `runtime_handle.block_on`.
/// Each accepted connection is handled by its own task spawned on the runtime. That
/// task allocates a connection id from `hub.registry`, serves an `McpBridgeServer`
/// holding a clone of `tx`, and removes the id from the registry when the session
/// (or a failed handshake) ends.
///
/// Returns a [`SocketGuard`] that unlinks the socket file when dropped. Dropping it
/// does not stop the accept thread.
///
/// # Errors
/// Propagates I/O errors from:
/// - directory creation or permission changes;
/// - inspecting or removing a stale socket;
/// - `bind`;
/// - switching the listener to non-blocking mode;
/// - registering the listener with the runtime.
///
/// Returns `AlreadyExists` if a non-socket entry occupies `path`. If binding
/// succeeded but a later setup step fails, the socket file is unlinked again.
pub fn spawn_mcp_socket_server(
    path: PathBuf,
    tx: Sender<McpCommand>,
    runtime_handle: tokio::runtime::Handle,
    hub: &crate::mcp_bridge::McpHub,
) -> io::Result<SocketGuard> {
    use std::os::unix::fs::FileTypeExt;

    // Pause after a failed `accept` before retrying.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
        // Only the conventional sessions directory is tightened; a parent chosen
        // by the user (explicit socket path) keeps its existing permissions.
        let in_default_sessions_dir = sessions_dir()
            .as_deref()
            .is_some_and(|default| parent == default);
        if in_default_sessions_dir {
            let mut perms = fs::metadata(parent)?.permissions();
            if perms.mode() & 0o777 != DIR_MODE {
                perms.set_mode(DIR_MODE);
                fs::set_permissions(parent, perms)?;
            }
        }
    }

    // `bind` fails if anything already exists at `path`, so a leftover socket must
    // be unlinked first. `symlink_metadata` inspects the entry itself (not a link
    // target), so a regular file, directory or symlink at a user-supplied path is
    // never deleted by mistake.
    match fs::symlink_metadata(&path) {
        Ok(meta) if meta.file_type().is_socket() => fs::remove_file(&path)?,
        Ok(_) => {
            return Err(io::Error::new(
                io::ErrorKind::AlreadyExists,
                format!("refusing to replace non-socket file at {}", path.display()),
            ));
        }
        Err(err) if err.kind() == io::ErrorKind::NotFound => {}
        Err(err) => return Err(err),
    }

    let std_listener = std::os::unix::net::UnixListener::bind(&path)?;
    // Arm the guard right after `bind`. If any later setup step fails, the early
    // return drops it and the freshly created socket file is unlinked.
    let guard = SocketGuard { path: path.clone() };

    let mut perms = fs::metadata(&path)?.permissions();
    perms.set_mode(SOCK_MODE);
    fs::set_permissions(&path, perms)?;

    // Convert to a tokio listener here rather than on the worker thread, so failures
    // reach the caller instead of silently leaving a socket nobody accepts on.
    // `from_std` must run inside a runtime context, hence the scoped `enter()`.
    std_listener.set_nonblocking(true)?;
    let listener = {
        let _runtime_ctx = runtime_handle.enter();
        tokio::net::UnixListener::from_std(std_listener)?
    };

    let registry_outer = hub.registry.clone();
    thread::spawn(move || {
        runtime_handle.block_on(async move {
            loop {
                let (stream, _addr) = match listener.accept().await {
                    Ok(pair) => pair,
                    Err(_) => {
                        // Persistent failures (e.g. EMFILE) would otherwise make this
                        // loop spin at 100% CPU. A blocking sleep is acceptable here:
                        // this future is the only work on its dedicated OS thread
                        // (driven by `Handle::block_on`), so no runtime worker stalls.
                        thread::sleep(ACCEPT_ERROR_BACKOFF);
                        continue;
                    }
                };
                let tx = tx.clone();
                let registry = registry_outer.clone();
                tokio::spawn(async move {
                    let (read, write) = stream.into_split();
                    let connection_id = registry.allocate_id();
                    let server = McpBridgeServer::new(tx, registry.clone(), connection_id);
                    if let Ok(service) = server.serve((read, write)).await {
                        let _ = service.waiting().await;
                    }
                    // Release the id on every exit path, including a failed handshake,
                    // so it is paired with the unconditional `allocate_id` above.
                    // NOTE(review): assumes `remove` tolerates an id whose session
                    // never finished initializing; confirm against the registry.
                    registry.remove(connection_id).await;
                });
            }
        });
    });

    Ok(guard)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test in this module that reads or writes process
    /// environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Environment variables that tests in this module overwrite.
    const SANDBOXED_VARS: [&str; 2] = ["XDG_STATE_HOME", "HOME"];

    /// Holds `ENV_LOCK` and, on drop, restores `SANDBOXED_VARS` to the values they
    /// had when the sandbox was created.
    ///
    /// Restoration also runs when a test panics. As a result, no test leaks a
    /// removed `HOME` (or a fake `XDG_STATE_HOME`) into other tests in the same
    /// binary. Every `unsafe { set_var / remove_var }` in this module relies on an
    /// `EnvSandbox` being held.
    struct EnvSandbox {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl Drop for EnvSandbox {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                // SAFETY: `_lock` is still held (fields are dropped only after this
                // method returns), so no other test in this module is accessing the
                // environment concurrently.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(key, value),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    /// Takes the environment lock (recovering it if a previous test panicked while
    /// holding it) and snapshots `SANDBOXED_VARS` for restoration on drop.
    fn env_sandbox() -> EnvSandbox {
        let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
        let saved = SANDBOXED_VARS
            .iter()
            .map(|&key| (key, std::env::var_os(key)))
            .collect();
        EnvSandbox { saved, _lock: lock }
    }

    #[test]
    fn default_socket_path_uses_xdg_state_home_when_set() {
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::set_var("XDG_STATE_HOME", "/tmp/xdg-state-trv-test");
            std::env::set_var("HOME", "/tmp/home-trv-test");
        }
        let path = default_socket_path().expect("path should resolve");
        assert!(path.starts_with("/tmp/xdg-state-trv-test/travelagent/sessions"));
        let name = path.file_name().unwrap().to_string_lossy();
        assert!(name.ends_with(".sock"));
        assert!(name.starts_with(&process::id().to_string()));
    }

    #[test]
    fn default_socket_path_falls_back_to_home_local_state() {
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::remove_var("XDG_STATE_HOME");
            std::env::set_var("HOME", "/tmp/home-trv-fb-test");
        }
        let path = default_socket_path().expect("path should resolve");
        assert!(path.starts_with("/tmp/home-trv-fb-test/.local/state/travelagent/sessions"));
    }

    #[test]
    fn default_socket_path_returns_none_without_env() {
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::remove_var("XDG_STATE_HOME");
            std::env::remove_var("HOME");
        }
        assert!(default_socket_path().is_none());
    }

    #[test]
    fn default_socket_path_ignores_empty_xdg_state_home() {
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::set_var("XDG_STATE_HOME", "");
            std::env::set_var("HOME", "/tmp/home-trv-empty");
        }
        let path = default_socket_path().expect("path should resolve");
        assert!(path.starts_with("/tmp/home-trv-empty/.local/state"));
    }

    #[test]
    fn resolve_socket_path_passes_through_absolute() {
        let p = resolve_socket_path(Some("/tmp/some/socket.sock")).unwrap();
        assert_eq!(p, PathBuf::from("/tmp/some/socket.sock"));
    }

    #[test]
    fn resolve_socket_path_passes_through_relative() {
        let p = resolve_socket_path(Some("trv.sock")).unwrap();
        assert_eq!(p, PathBuf::from("trv.sock"));
    }

    #[test]
    fn resolve_socket_path_expands_tilde() {
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::set_var("HOME", "/tmp/home-trv-tilde");
        }
        let p = resolve_socket_path(Some("~/x/y.sock")).unwrap();
        assert_eq!(p, PathBuf::from("/tmp/home-trv-tilde/x/y.sock"));
        let p = resolve_socket_path(Some("~")).unwrap();
        assert_eq!(p, PathBuf::from("/tmp/home-trv-tilde"));
    }

    #[test]
    fn resolve_socket_path_none_falls_back_to_default() {
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::set_var("XDG_STATE_HOME", "/tmp/xdg-trv-resolve");
        }
        let p = resolve_socket_path(None).unwrap();
        assert!(p.starts_with("/tmp/xdg-trv-resolve/travelagent/sessions"));
    }

    #[test]
    fn sweep_stale_sockets_removes_dead_pid_socket() {
        let dir = tempfile::tempdir().unwrap();
        let dead_pid = i32::MAX - 1;
        let dead_sock = dir.path().join(format!("{dead_pid}.sock"));
        fs::write(&dead_sock, "").unwrap();
        let preserved = dir.path().join("manual.sock");
        fs::write(&preserved, "").unwrap();
        let live_sock = dir.path().join(format!("{}.sock", process::id()));
        fs::write(&live_sock, "").unwrap();
        sweep_stale_sockets(dir.path());
        assert!(!dead_sock.exists(), "dead-PID socket should be removed");
        assert!(preserved.exists(), "non-PID-named file must be preserved");
        assert!(live_sock.exists(), "live-PID socket must be preserved");
    }

    #[test]
    fn sweep_stale_sockets_ignores_missing_directory() {
        sweep_stale_sockets(Path::new("/tmp/this/does/not/exist/for/sure/xyz"));
    }

    #[test]
    fn spawn_and_connect_roundtrip() {
        use std::io::{Read, Write};
        use std::os::unix::net::UnixStream;
        use std::sync::mpsc;
        // `spawn_mcp_socket_server` reads the environment (via `sessions_dir`), so
        // keep other tests from mutating it meanwhile.
        let _env = env_sandbox();
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("rt.sock");
        let (tx, rx) = mpsc::channel::<McpCommand>();
        std::thread::spawn(move || {
            while let Ok(cmd) = rx.recv() {
                let _ = cmd.reply.send(r#"{"ok":true}"#.to_string());
            }
        });
        let runtime = crate::test_support::runtime_handle();
        let hub = crate::mcp_bridge::McpHub::start(&runtime);
        let _guard = spawn_mcp_socket_server(sock_path.clone(), tx, runtime, &hub).expect("spawn");
        let start = std::time::Instant::now();
        while !sock_path.exists() {
            if start.elapsed() > std::time::Duration::from_secs(2) {
                panic!("socket never appeared at {sock_path:?}");
            }
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
        let mut stream = UnixStream::connect(&sock_path).expect("connect");
        let init = br#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}
"#;
        stream.write_all(init).unwrap();
        stream
            .set_read_timeout(Some(std::time::Duration::from_secs(2)))
            .unwrap();
        let mut buf = [0u8; 4096];
        let n = stream.read(&mut buf).expect("read initialize reply");
        let got = std::str::from_utf8(&buf[..n]).unwrap();
        assert!(got.contains("\"id\":1"), "no id in reply: {got}");
        assert!(
            got.contains("result") || got.contains("error"),
            "no result/error in reply: {got}"
        );
        let mut stream2 = UnixStream::connect(&sock_path).expect("second connect");
        stream2.write_all(init).unwrap();
        stream2
            .set_read_timeout(Some(std::time::Duration::from_secs(2)))
            .unwrap();
        let n = stream2.read(&mut buf).expect("read reply on 2nd conn");
        assert!(n > 0);
    }

    #[test]
    fn socket_guard_unlinks_on_drop() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("guard.sock");
        fs::write(&path, "").unwrap();
        assert!(path.exists());
        {
            let _guard = SocketGuard { path: path.clone() };
        }
        assert!(!path.exists(), "SocketGuard::drop should unlink");
    }

    #[test]
    fn spawn_mcp_socket_does_not_chmod_user_supplied_parent() {
        use std::os::unix::fs::PermissionsExt;
        use std::sync::mpsc;
        let _env = env_sandbox();
        // SAFETY: environment access is serialized by `_env`.
        unsafe {
            std::env::set_var("XDG_STATE_HOME", "/tmp/xdg-state-trv-chmod-test");
            std::env::set_var("HOME", "/tmp/home-trv-chmod-test");
        }
        let dir = tempfile::tempdir().unwrap();
        let parent = dir.path();
        fs::set_permissions(parent, fs::Permissions::from_mode(0o755)).unwrap();
        let before_mode = fs::metadata(parent).unwrap().permissions().mode() & 0o777;
        assert_eq!(before_mode, 0o755, "precondition");
        let sock_path = parent.join("user.sock");
        let (tx, _rx) = mpsc::channel::<McpCommand>();
        let runtime = crate::test_support::runtime_handle();
        let hub = crate::mcp_bridge::McpHub::start(&runtime);
        let _guard = spawn_mcp_socket_server(sock_path.clone(), tx, runtime, &hub)
            .expect("bind at user-supplied path succeeds");
        let start = std::time::Instant::now();
        while !sock_path.exists() {
            if start.elapsed() > std::time::Duration::from_secs(2) {
                panic!("socket never appeared at {sock_path:?}");
            }
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
        let after_mode = fs::metadata(parent).unwrap().permissions().mode() & 0o777;
        assert_eq!(
            after_mode, 0o755,
            "user-supplied parent perms must NOT be tightened to 0700; was {after_mode:o}"
        );
    }
}