use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex, oneshot};
use tokio::task::JoinHandle;
use crate::tools::mcp::protocol::{McpRequest, McpResponse};
use crate::tools::tool::ToolError;
/// Abstraction over a channel that delivers MCP requests and yields responses.
///
/// Implementors must be `Send + Sync`. The trait is declared through
/// `async_trait`, so its async methods can be used behind a trait object.
#[async_trait]
pub trait McpTransport: Send + Sync {
/// Sends `request` and resolves to the matching [`McpResponse`].
///
/// `headers` is an extra string-to-string map passed alongside the request.
/// NOTE(review): presumably HTTP headers that only HTTP-based transports
/// use — confirm how other implementors treat them.
///
/// # Errors
///
/// Returns a [`ToolError`] when the implementor cannot produce a response.
async fn send(
&self,
request: &McpRequest,
headers: &HashMap<String, String>,
) -> Result<McpResponse, ToolError>;
/// Shuts the transport down.
///
/// NOTE(review): the exact cleanup (closing a stream, stopping a child
/// process, ...) is implementor-specific and not visible here — confirm.
///
/// # Errors
///
/// Returns a [`ToolError`] when the implementor fails to shut down.
async fn shutdown(&self) -> Result<(), ToolError>;
/// Reports whether this transport supports HTTP-specific features.
///
/// The default implementation returns `false`; implementors override it to
/// opt in. NOTE(review): which features this flag gates is decided by the
/// callers — verify against them.
fn supports_http_features(&self) -> bool {
false
}
}
/// Writes `request` to `writer` as one newline-terminated JSON line.
///
/// The request is serialized with `serde_json`, its bytes are written, a
/// single `\n` follows, and the writer is flushed before returning.
///
/// # Errors
///
/// Every failure (serialization, either write, or the flush) is reported as
/// [`ToolError::ExternalService`] with a message naming the failing step.
pub async fn write_jsonrpc_line(
    writer: &mut (impl AsyncWrite + Unpin),
    request: &McpRequest,
) -> Result<(), ToolError> {
    let payload = match serde_json::to_string(request) {
        Ok(text) => text,
        Err(e) => {
            return Err(ToolError::ExternalService(format!(
                "Failed to serialize JSON-RPC request: {e}"
            )));
        }
    };

    if let Err(e) = writer.write_all(payload.as_bytes()).await {
        return Err(ToolError::ExternalService(format!(
            "Failed to write JSON-RPC request: {e}"
        )));
    }

    // The newline is the message delimiter of the line-based framing.
    if let Err(e) = writer.write_all(b"\n").await {
        return Err(ToolError::ExternalService(format!(
            "Failed to write newline: {e}"
        )));
    }

    match writer.flush().await {
        Ok(()) => Ok(()),
        Err(e) => Err(ToolError::ExternalService(format!(
            "Failed to flush JSON-RPC writer: {e}"
        ))),
    }
}
/// Spawns a task that reads newline-delimited JSON-RPC responses from
/// `reader` and hands each one to the waiter registered under its id in
/// `pending`.
///
/// Per line:
/// - a line that does not parse as [`McpResponse`] is logged (with a preview
///   of at most 200 characters) and skipped;
/// - a message without an id is treated as a notification and skipped;
/// - a response whose id has no registered waiter is logged and dropped;
/// - otherwise the waiter is removed from `pending` and sent the response.
///
/// The task ends at end-of-input or on the first read error (which is
/// logged). Before finishing it clears `pending`: dropping the remaining
/// senders makes every still-waiting receiver resolve with a closed-channel
/// error right away instead of waiting for a response that can no longer
/// arrive from this reader.
///
/// `server_name` is only used as a prefix in log messages.
pub fn spawn_jsonrpc_reader<R: AsyncBufRead + Unpin + Send + 'static>(
    reader: R,
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>>,
    server_name: String,
) -> JoinHandle<()> {
    // Maximum number of characters of an unparseable line echoed to the log.
    const PREVIEW_CHARS: usize = 200;

    tokio::spawn(async move {
        let mut lines = reader.lines();
        loop {
            let line = match lines.next_line().await {
                Ok(Some(line)) => line,
                Ok(None) => break,
                Err(e) => {
                    // Surface the reason instead of ending silently.
                    tracing::warn!(
                        "[{}] JSON-RPC reader stopped on read error: {}",
                        server_name,
                        e
                    );
                    break;
                }
            };

            let response = match serde_json::from_str::<McpResponse>(&line) {
                Ok(resp) => resp,
                Err(e) => {
                    // Cut on a char boundary and only flag truncation when
                    // characters were actually dropped (byte length would
                    // over-report for multibyte text).
                    let (preview, ellipsis) = match line.char_indices().nth(PREVIEW_CHARS) {
                        Some((cut, _)) => (&line[..cut], "…"),
                        None => (line.as_str(), ""),
                    };
                    tracing::debug!(
                        "[{}] Failed to parse JSON-RPC response: {} — line: {}{}",
                        server_name,
                        e,
                        preview,
                        ellipsis
                    );
                    continue;
                }
            };

            let Some(id) = response.id else {
                tracing::debug!(
                    "[{}] Received JSON-RPC notification (no id), skipping dispatch",
                    server_name
                );
                continue;
            };

            // The guard is a temporary, so the map is unlocked before sending.
            let waiter = pending.lock().await.remove(&id);
            match waiter {
                Some(tx) => {
                    // The waiter may have given up already; that is not an error.
                    let _ = tx.send(response);
                }
                None => {
                    tracing::debug!(
                        "[{}] Received response for unknown request id {}",
                        server_name,
                        id
                    );
                }
            }
        }

        // No further responses can arrive: drop all outstanding senders so
        // their receivers fail fast rather than running into their timeout.
        pending.lock().await.clear();
        tracing::debug!("[{}] JSON-RPC reader finished", server_name);
    })
}
/// Sends `request` over a line-based JSON-RPC stream and waits for its reply.
///
/// A request without an id is a notification: it is written and an empty
/// response (no id, result or error) is returned immediately. Otherwise a
/// oneshot waiter is registered in `pending` under the request id *before*
/// the request is written, and the call waits up to `timeout_duration` for
/// that waiter to be resolved.
///
/// # Errors
///
/// - the error from [`write_jsonrpc_line`] if writing fails (the waiter is
///   unregistered first);
/// - [`ToolError::ExternalService`] if the waiter's sender is dropped before
///   a response is delivered;
/// - [`ToolError::ExternalService`] if no response arrives within
///   `timeout_duration`.
///
/// `server_name` is only used as a prefix in error messages.
pub(crate) async fn stream_transport_send<W: AsyncWrite + Unpin>(
    writer: &Mutex<W>,
    pending: &Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>,
    request: &McpRequest,
    server_name: &str,
    timeout_duration: std::time::Duration,
) -> Result<McpResponse, ToolError> {
    let Some(id) = request.id else {
        // Notification: nothing to wait for.
        let mut out = writer.lock().await;
        write_jsonrpc_line(&mut *out, request).await?;
        return Ok(McpResponse {
            jsonrpc: "2.0".to_string(),
            id: None,
            result: None,
            error: None,
        });
    };

    // Register the waiter first so a fast reply cannot be missed.
    let (reply_tx, reply_rx) = oneshot::channel();
    pending.lock().await.insert(id, reply_tx);

    {
        let mut out = writer.lock().await;
        if let Err(write_err) = write_jsonrpc_line(&mut *out, request).await {
            pending.lock().await.remove(&id);
            return Err(write_err);
        }
    }

    let failure = match tokio::time::timeout(timeout_duration, reply_rx).await {
        Ok(Ok(response)) => return Ok(response),
        Ok(Err(_)) => format!(
            "[{}] MCP server closed connection before responding to request {:?}",
            server_name, request.id
        ),
        Err(_) => format!(
            "[{}] Timeout waiting for response to request {:?} after {:?}",
            server_name, request.id, timeout_duration
        ),
    };

    // Both failure modes unregister the waiter before reporting.
    pending.lock().await.remove(&id);
    Err(ToolError::ExternalService(failure))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared map the reader task uses to route responses to waiters.
    type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<McpResponse>>>>;

    /// Builds a pending map holding one waiter registered under `id` and
    /// returns it together with the receiving end of that waiter.
    async fn pending_with_waiter(id: u64) -> (PendingMap, oneshot::Receiver<McpResponse>) {
        let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
        let (tx, rx) = oneshot::channel();
        pending.lock().await.insert(id, tx);
        (pending, rx)
    }

    /// The written line is valid JSON carrying the request and ends in `\n`.
    #[tokio::test]
    async fn test_write_jsonrpc_line_serializes_and_flushes() {
        let request = McpRequest {
            jsonrpc: "2.0".into(),
            id: Some(1),
            method: "test/method".into(),
            params: None,
        };

        let mut sink: Vec<u8> = Vec::new();
        write_jsonrpc_line(&mut sink, &request)
            .await
            .expect("write should succeed");

        let written = String::from_utf8(sink).expect("should be valid UTF-8");
        assert!(written.ends_with('\n'));

        let parsed: serde_json::Value =
            serde_json::from_str(written.trim()).expect("should be valid JSON");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["method"], "test/method");
    }

    /// A well-formed response reaches the waiter registered under its id.
    #[tokio::test]
    async fn test_spawn_jsonrpc_reader_dispatches_response() {
        let response = McpResponse {
            jsonrpc: "2.0".into(),
            id: Some(42),
            result: Some(serde_json::json!({"tools": []})),
            error: None,
        };
        let line = format!("{}\n", serde_json::to_string(&response).unwrap());
        let reader = std::io::Cursor::new(line.into_bytes());

        let (pending, rx) = pending_with_waiter(42).await;
        let handle = spawn_jsonrpc_reader(reader, pending.clone(), "test".into());

        let resp = rx.await.expect("should receive response");
        assert_eq!(resp.id, Some(42));
        assert!(resp.result.is_some());
        handle.await.expect("reader task should finish");
    }

    /// An unparseable line does not stop the reader from handling later ones.
    #[tokio::test]
    async fn test_spawn_jsonrpc_reader_skips_invalid_lines() {
        let input = "this is not json\n{\"jsonrpc\":\"2.0\",\"id\":7,\"result\":null}\n";
        let reader = std::io::Cursor::new(input.as_bytes().to_vec());

        let (pending, rx) = pending_with_waiter(7).await;
        let handle = spawn_jsonrpc_reader(reader, pending.clone(), "test".into());

        let resp = rx
            .await
            .expect("should receive response despite earlier invalid line");
        assert_eq!(resp.id, Some(7));
        handle.await.expect("reader task should finish");
    }

    /// A notification (no id) must not be mistaken for the response to id 0.
    #[tokio::test]
    async fn test_notification_does_not_resolve_pending_id_zero() {
        let notification = r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{}}"#;
        let real_response = r#"{"jsonrpc":"2.0","id":0,"result":{"ok":true}}"#;
        let input = format!("{notification}\n{real_response}\n");
        let reader = std::io::Cursor::new(input.into_bytes());

        let (pending, rx) = pending_with_waiter(0).await;
        let handle = spawn_jsonrpc_reader(reader, pending.clone(), "test".into());

        let resp = rx.await.expect("should receive the real id=0 response");
        assert_eq!(resp.id, Some(0));
        assert!(resp.result.is_some());
        handle.await.expect("reader task should finish");
    }
}