use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::io::BufReader;
use tokio::net::UnixStream;
use tokio::sync::{Mutex, oneshot};
use tokio::task::JoinHandle;
use crate::tools::mcp::protocol::{McpRequest, McpResponse};
use crate::tools::mcp::transport::{McpTransport, spawn_jsonrpc_reader, stream_transport_send};
use crate::tools::tool::ToolError;
/// MCP transport that exchanges JSON-RPC messages with a server over a Unix
/// domain socket.
///
/// The connected stream is split in two. The write half is kept here behind a
/// mutex and used by `stream_transport_send`. The read half is moved into a
/// background task started by `spawn_jsonrpc_reader`, which presumably routes
/// each incoming response to the matching sender in `pending`.
pub struct UnixMcpTransport {
    /// Socket path this transport connected to; used in diagnostics.
    socket_path: PathBuf,
    /// Human-readable server label, used as the `[name]` prefix in errors and logs.
    server_name: String,
    /// Write half of the socket, shared with the request-sending helper.
    writer: Arc<Mutex<tokio::io::WriteHalf<UnixStream>>>,
    /// One-shot reply channels for in-flight requests, keyed by a `u64`
    /// (presumably the JSON-RPC request id — confirm in `stream_transport_send`).
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>>,
    /// Handle to the background reader task; `None` once it has been aborted.
    reader_handle: Mutex<Option<JoinHandle<()>>>,
}

impl Drop for UnixMcpTransport {
    /// Aborts the background reader task if `shutdown` never did.
    ///
    /// Dropping a `JoinHandle` merely detaches its task. The detached reader
    /// would keep owning the socket's read half, and therefore the connection,
    /// until the server hung up. Aborting it here releases the socket with the
    /// transport.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so no async locking is needed.
        if let Some(handle) = self.reader_handle.get_mut().take() {
            handle.abort();
        }
    }
}
impl UnixMcpTransport {
pub async fn connect(
name: impl Into<String>,
socket_path: impl AsRef<Path>,
) -> Result<Self, ToolError> {
let server_name = name.into();
let socket_path = socket_path.as_ref().to_path_buf();
let stream = UnixStream::connect(&socket_path).await.map_err(|e| {
ToolError::ExternalService(format!(
"[{}] Failed to connect to Unix socket '{}': {}",
server_name,
socket_path.display(),
e
))
})?;
let (read_half, write_half) = tokio::io::split(stream);
let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>> =
Arc::new(Mutex::new(HashMap::new()));
let reader = BufReader::new(read_half);
let reader_handle = spawn_jsonrpc_reader(reader, pending.clone(), server_name.clone());
Ok(Self {
socket_path,
server_name,
writer: Arc::new(Mutex::new(write_half)),
pending,
reader_handle: Mutex::new(Some(reader_handle)),
})
}
#[cfg(test)]
pub(crate) fn socket_path(&self) -> &Path {
&self.socket_path
}
#[cfg(test)]
pub(crate) fn server_name(&self) -> &str {
&self.server_name
}
}
#[async_trait]
impl McpTransport for UnixMcpTransport {
async fn send(
&self,
request: &McpRequest,
_headers: &HashMap<String, String>,
) -> Result<McpResponse, ToolError> {
stream_transport_send(
&self.writer,
&self.pending,
request,
&self.server_name,
Duration::from_secs(30),
)
.await
}
async fn shutdown(&self) -> Result<(), ToolError> {
if let Some(handle) = self.reader_handle.lock().await.take() {
handle.abort();
}
{
let mut pending = self.pending.lock().await;
pending.clear(); }
tracing::debug!(
"[{}] Unix transport shut down (socket: {})",
self.server_name,
self.socket_path.display()
);
Ok(())
}
fn supports_http_features(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader as TokioBufReader};
    use tokio::net::UnixListener;

    /// Creates a fresh temp directory and returns it along with the path of a
    /// socket named `file_name` inside it.
    ///
    /// The returned `TempDir` guard deletes the directory when dropped, so
    /// callers must bind it to a named variable for the duration of the test.
    fn temp_socket(file_name: &str) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join(file_name);
        (dir, path)
    }

    #[tokio::test]
    async fn test_connect_nonexistent_socket_fails() {
        let (_tmp_dir, socket_path) = temp_socket("nonexistent.sock");

        // `UnixMcpTransport` is not `Debug`, so match instead of `unwrap_err`.
        let message = match UnixMcpTransport::connect("test", &socket_path).await {
            Ok(_) => panic!("should be an error"),
            Err(err) => err.to_string(),
        };

        assert!(
            message.contains("Failed to connect"),
            "Error should mention connection failure: {}",
            message
        );
    }

    #[tokio::test]
    async fn test_round_trip_via_unix_socket() {
        let (_tmp_dir, socket_path) = temp_socket("test.sock");
        let listener = UnixListener::bind(&socket_path).expect("bind listener");

        // Fake server: read one request line and answer it with an empty tool
        // list that echoes the request id.
        let handler = tokio::spawn(async move {
            let (stream, _peer) = listener.accept().await.expect("accept connection");
            let (read_half, mut write_half) = tokio::io::split(stream);
            let mut reader = TokioBufReader::new(read_half);

            let mut request_line = String::new();
            reader
                .read_line(&mut request_line)
                .await
                .expect("read request line");
            let req: McpRequest =
                serde_json::from_str(&request_line).expect("parse request");

            let reply = McpResponse {
                jsonrpc: "2.0".to_string(),
                id: req.id,
                result: Some(serde_json::json!({"tools": []})),
                error: None,
            };
            let mut payload = serde_json::to_vec(&reply).expect("serialize response");
            payload.push(b'\n');

            write_half
                .write_all(&payload)
                .await
                .expect("write response");
            write_half.flush().await.expect("flush");
        });

        let transport = UnixMcpTransport::connect("test-uds", &socket_path)
            .await
            .expect("connect should succeed");
        assert_eq!(transport.socket_path(), socket_path.as_path());
        assert_eq!(transport.server_name(), "test-uds");

        let request = McpRequest::list_tools(42);
        let no_headers = HashMap::new();
        let response = transport
            .send(&request, &no_headers)
            .await
            .expect("send");

        assert_eq!(response.id, Some(42));
        assert!(response.result.is_some());
        assert!(response.error.is_none());

        transport.shutdown().await.expect("shutdown");
        handler.await.expect("handler task");
    }

    #[tokio::test]
    async fn test_shutdown_is_idempotent() {
        let (_tmp_dir, socket_path) = temp_socket("idle.sock");
        let listener = UnixListener::bind(&socket_path).expect("bind listener");

        // The peer only needs to accept; it never exchanges any messages.
        let _handler = tokio::spawn(async move {
            let _stream = listener.accept().await;
        });

        let transport = UnixMcpTransport::connect("test-idle", &socket_path)
            .await
            .expect("connect");

        for label in ["first shutdown", "second shutdown"] {
            transport.shutdown().await.expect(label);
        }
    }
}