use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::post;
use axum::{Json, Router};
use serde_json::json;
use tokio::sync::mpsc;
use tracing::{debug, error, info};
use super::{RequestHandler, Transport, TransportMessage};
use crate::error::{Error, Result};
use crate::protocol::Request;
/// Shared state handed to the axum router via `with_state`.
///
/// Holds the request handler that `handle_mcp_request` invokes for every incoming
/// `POST /mcp` request.
struct HttpState {
    // Called once per HTTP request; returns the sender the request is forwarded on.
    request_handler: RequestHandler,
}
/// MCP transport that serves requests over HTTP on the `POST /mcp` route using axum.
pub struct HttpTransport {
    /// Host the server listens on; combined with `port` to form the listen address.
    host: String,
    /// TCP port the server listens on.
    port: u16,
    /// Handler used to dispatch incoming requests. Starting the transport fails with a
    /// transport error while this is `None`.
    request_handler: Option<RequestHandler>,
}
impl fmt::Debug for HttpTransport {
    /// Shows the listen host and port, plus whether a request handler is installed.
    ///
    /// The handler itself is not `Debug`, so only its presence is reported, rendered as
    /// `"<handler: true>"` or `"<handler: false>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_summary = format!("<handler: {}>", self.request_handler.is_some());
        let mut builder = f.debug_struct("HttpTransport");
        builder.field("host", &self.host);
        builder.field("port", &self.port);
        builder.field("request_handler", &handler_summary);
        builder.finish()
    }
}
impl HttpTransport {
    /// Creates a transport that will listen on `host:port`.
    ///
    /// No request handler is installed yet. `start` fails until one has been provided
    /// through `set_request_handler`.
    pub fn new(host: String, port: u16) -> Self {
        let request_handler = None;
        Self {
            host,
            port,
            request_handler,
        }
    }
}
/// Turns the configured listen host and port into a [`SocketAddr`].
///
/// A bare IP literal (IPv4, or unbracketed IPv6 such as `::1`) is combined with `port`
/// directly. Otherwise the `"{host}:{port}"` form is parsed, which covers bracketed IPv6
/// such as `[::1]` or `[fe80::1%2]`. Formatting a bare IPv6 host as `"{host}:{port}"` makes
/// the port separator ambiguous and fails to parse, which is why the IP literal is tried
/// first.
///
/// # Errors
///
/// Returns the `AddrParseError` from the `host:port` parse when neither form is valid.
#[cfg(feature = "http")]
fn parse_listen_addr(
    host: &str,
    port: u16,
) -> std::result::Result<SocketAddr, std::net::AddrParseError> {
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return Ok(SocketAddr::new(ip, port));
    }
    format!("{host}:{port}").parse()
}

#[cfg(feature = "http")]
#[async_trait]
impl Transport for HttpTransport {
    /// Starts the HTTP server and serves `POST /mcp` until the server stops.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Transport`] in any of these cases:
    /// - no request handler has been set;
    /// - `host`/`port` do not form a valid socket address;
    /// - the address cannot be bound (for example, the port is already in use);
    /// - the server fails while running.
    async fn start(&self) -> Result<()> {
        info!("正在启动 HTTP 传输,地址 {}:{}", self.host, self.port);
        let request_handler = self
            .request_handler
            .clone()
            .ok_or_else(|| Error::Transport("请求处理器未设置".to_string()))?;
        let app_state = Arc::new(HttpState { request_handler });
        let app = Router::new()
            .route("/mcp", post(handle_mcp_request))
            .with_state(app_state);
        let addr = parse_listen_addr(&self.host, self.port)
            .map_err(|e| Error::Transport(format!("无效的服务器地址: {e}")))?;
        info!("HTTP 服务器监听于 {}", addr);
        // `Server::bind` panics when binding fails; `try_bind` reports the failure so it can
        // be surfaced as a transport error instead of aborting the process.
        axum::Server::try_bind(&addr)
            .map_err(|e| Error::Transport(format!("绑定地址 {addr} 失败: {e}")))?
            .serve(app.into_make_service())
            .await
            .map_err(|e| Error::Transport(format!("HTTP 服务器错误: {e}")))?;
        Ok(())
    }

    /// Installs the handler that `start` will use to dispatch incoming requests.
    fn set_request_handler(&mut self, handler: RequestHandler) {
        self.request_handler = Some(handler);
    }
}
async fn handle_mcp_request(
State(state): State<Arc<HttpState>>, Json(request): Json<Request>,
) -> impl IntoResponse {
debug!("收到 HTTP 请求,工具: {}", request.tool);
let (resp_tx, mut resp_rx) = mpsc::channel(1);
let tx = (state.request_handler)(request.clone());
if let Err(e) = tx.send((request, resp_tx)).await {
error!("发送请求失败: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": format!("处理请求失败: {}", e)
})),
);
}
match resp_rx.recv().await {
Some(TransportMessage::Response(response)) => (StatusCode::OK, Json(json!(response))),
Some(TransportMessage::Error(error)) => {
let status = match error.code {
400..=499 => StatusCode::BAD_REQUEST,
500..=599 => StatusCode::INTERNAL_SERVER_ERROR,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
(status, Json(json!(error)))
}
Some(TransportMessage::Notification(_)) => {
(
StatusCode::BAD_REQUEST,
Json(json!({
"error": "收到通知而非响应"
})),
)
}
None => {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": "未收到响应"
})),
)
}
}
}