use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::{Mutex, broadcast, oneshot};
use super::super::types::{
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, MCP_PROTOCOL_VERSION,
NotificationReceiver,
};
use echo_core::error::{McpError, ReactError, Result};
use super::McpTransport;
/// Client-side MCP transport that exchanges JSON-RPC messages with a server over HTTP.
///
/// Every message is POSTed to a single endpoint. State that changes while the transport is
/// in use (the server-assigned session id and the table of requests awaiting a reply) lives
/// behind `tokio::sync::Mutex` so it can be updated through `&self`.
pub struct HttpTransport {
    /// HTTP client shared by every request issued through this transport.
    client: reqwest::Client,
    /// URL that all requests are sent to.
    endpoint: String,
    /// Extra header name/value pairs attached to outgoing requests.
    headers: HashMap<String, String>,
    /// Counter used to assign JSON-RPC request ids.
    next_id: Arc<AtomicU64>,
    /// Session id most recently returned by the server in `Mcp-Session-Id`, if any.
    session_id: Arc<Mutex<Option<String>>>,
    /// Broadcast sender whose subscribers are handed out as notification receivers.
    notification_tx: broadcast::Sender<JsonRpcNotification>,
    /// Requests still awaiting a response, keyed by request id, with the channel that completes each.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
}
impl HttpTransport {
pub fn new(endpoint: String, headers: HashMap<String, String>) -> Self {
let (notification_tx, _) = broadcast::channel(256);
Self {
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()
.unwrap_or_else(|_| {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(60))
.build()
.unwrap_or_default()
}),
endpoint: endpoint.trim_end_matches('/').to_string(),
headers,
next_id: Arc::new(AtomicU64::new(1)),
session_id: Arc::new(Mutex::new(None)),
notification_tx,
pending: Arc::new(Mutex::new(HashMap::new())),
}
}
}
impl McpTransport for HttpTransport {
fn send(&self, request: JsonRpcRequest) -> BoxFuture<'_, Result<JsonRpcResponse>> {
Box::pin(async move {
let mut request = request;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
request.id = Some(Value::Number(id.into()));
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(id, tx);
}
let mut builder = self
.client
.post(&self.endpoint)
.header("Content-Type", "application/json")
.header("MCP-Protocol-Version", MCP_PROTOCOL_VERSION)
.header("Accept", "application/json, text/event-stream")
.json(&request);
{
let sid = self.session_id.lock().await;
if let Some(ref session_id) = *sid {
builder = builder.header("Mcp-Session-Id", session_id.as_str());
}
}
for (k, v) in &self.headers {
builder = builder.header(k, v);
}
const MAX_RETRIES: u32 = 3;
const BASE_DELAY_MS: u64 = 500;
let mut retry_count: u32 = 0;
let response = loop {
let req = match builder.try_clone() {
Some(r) => r,
None => {
self.pending.lock().await.remove(&id);
return Err(ReactError::Mcp(McpError::ConnectionFailed(
"无法复制 HTTP 请求".to_string(),
)));
}
};
match req.send().await {
Ok(resp) => break resp,
Err(e) => {
if retry_count < MAX_RETRIES && is_retryable_error(&e) {
retry_count += 1;
let delay_ms = BASE_DELAY_MS * 2u64.pow(retry_count - 1);
let jitter = delay_ms / 4;
let delay = if jitter > 0 {
let r = (delay_ms as i64 - jitter as i64
+ (std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos()
as i64
% (jitter * 2) as i64))
.max(0) as u64;
Duration::from_millis(r)
} else {
Duration::from_millis(delay_ms)
};
tracing::warn!(
attempt = retry_count,
max = MAX_RETRIES,
delay_ms = delay.as_millis() as u64,
error = %e,
"MCP HTTP 请求失败,重试中..."
);
tokio::time::sleep(delay).await;
} else {
self.pending.lock().await.remove(&id);
return Err(ReactError::Mcp(McpError::ConnectionFailed(format!(
"HTTP 请求失败: {}",
e
))));
}
}
}
};
if let Some(new_session_id) = response.headers().get("mcp-session-id")
&& let Ok(sid) = new_session_id.to_str()
{
let mut sid_guard = self.session_id.lock().await;
*sid_guard = Some(sid.to_string());
tracing::debug!("HTTP: 保存 Mcp-Session-Id: {}", sid);
}
let status = response.status().as_u16();
if status == 202 {
tracing::debug!("HTTP: 收到 202 Accepted (id={}),等待异步响应", id);
let result =
match tokio::time::timeout(std::time::Duration::from_secs(60), rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => Err(ReactError::Mcp(McpError::TransportClosed)),
Err(_) => {
self.pending.lock().await.remove(&id);
Err(ReactError::Mcp(McpError::ProtocolError(format!(
"等待 HTTP 异步响应超时 (id={})",
id
))))
}
}?;
return Ok(result);
}
self.pending.lock().await.remove(&id);
if !response.status().is_success() {
let body = response.text().await.unwrap_or_default();
return Err(ReactError::Mcp(McpError::ConnectionFailed(format!(
"HTTP 错误 {}: {}",
status, body
))));
}
let rpc_response: JsonRpcResponse = response.json().await.map_err(|e| {
ReactError::Mcp(McpError::ProtocolError(format!(
"解析 HTTP 响应失败: {}",
e
)))
})?;
Ok(rpc_response)
})
}
fn notify(&self, notification: JsonRpcNotification) -> BoxFuture<'_, Result<()>> {
Box::pin(async move {
let mut builder = self
.client
.post(&self.endpoint)
.header("Content-Type", "application/json")
.header("MCP-Protocol-Version", MCP_PROTOCOL_VERSION)
.json(¬ification);
{
let sid = self.session_id.lock().await;
if let Some(ref session_id) = *sid {
builder = builder.header("Mcp-Session-Id", session_id.as_str());
}
}
for (k, v) in &self.headers {
builder = builder.header(k, v);
}
let _ = builder.send().await;
Ok(())
})
}
fn close(&self) -> BoxFuture<'_, ()> {
Box::pin(async move {
let notification = JsonRpcNotification::new("notifications/cancelled", None);
let _ = self.notify(notification).await;
self.pending.lock().await.clear();
})
}
fn notification_rx(&self) -> Option<Arc<dyn super::super::types::JsonRpcNotificationReceiver>> {
Some(Arc::new(NotificationReceiver::new(
self.notification_tx.subscribe(),
)))
}
}
fn is_retryable_error(e: &reqwest::Error) -> bool {
if e.is_timeout() || e.is_connect() {
return true;
}
if let Some(status) = e.status() {
let code = status.as_u16();
return code == 502 || code == 503 || code == 504;
}
let msg = e.to_string().to_lowercase();
msg.contains("connection reset")
|| msg.contains("connection refused")
|| msg.contains("broken pipe")
|| msg.contains("eof")
|| msg.contains("unexpected eof")
|| msg.contains("dns")
|| msg.contains("tls")
|| msg.contains("timed out")
}