fastmcp 0.0.0

A Rust framework for building Model Context Protocol (MCP) services
Documentation
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::routing::post;
use axum::{Json, Router};
use serde_json::json;
use tokio::sync::mpsc;
use tracing::{debug, error, info};

use super::{RequestHandler, Transport, TransportMessage};
use crate::error::{Error, Result};
use crate::protocol::Request;

/// Shared state for the HTTP transport's axum routes.
struct HttpState {
    /// Request handler used to dispatch each incoming MCP request.
    request_handler: RequestHandler,
}

/// HTTP transport implementation: serves MCP requests as JSON `POST`s on `/mcp`.
pub struct HttpTransport {
    /// Host the server binds to; combined with `port` to form the listen address.
    host: String,

    /// Port the server listens on.
    port: u16,

    /// Handler for incoming requests. `None` until one is installed; starting the
    /// transport without a handler fails with a transport error.
    request_handler: Option<RequestHandler>,
}

// Hand-written because `RequestHandler` is not `Debug`: the handler itself is not
// printed, only whether one has been installed.
impl fmt::Debug for HttpTransport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_summary = format!("<handler: {}>", self.request_handler.is_some());

        let mut builder = f.debug_struct("HttpTransport");
        builder.field("host", &self.host);
        builder.field("port", &self.port);
        builder.field("request_handler", &handler_summary);
        builder.finish()
    }
}

impl HttpTransport {
    /// 创建新的 HTTP 传输
    pub fn new(host: String, port: u16) -> Self {
        Self {
            host,
            port,
            request_handler: None,
        }
    }
}

/// Resolves the configured host and port into a socket address to bind.
///
/// IP literals are handled directly; IPv6 literals may be written with or without
/// brackets (`::1` or `[::1]`). Any other host (e.g. `localhost`) is resolved with
/// tokio's asynchronous resolver, and the first returned address is used.
///
/// # Errors
///
/// Returns `Error::Transport` if the host cannot be resolved to any address.
#[cfg(feature = "http")]
async fn resolve_bind_addr(host: &str, port: u16) -> Result<SocketAddr> {
    // Fast path: no resolver round-trip for plain IP literals.
    let literal = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    if let Ok(ip) = literal.parse::<std::net::IpAddr>() {
        return Ok(SocketAddr::new(ip, port));
    }

    // Hostnames need name resolution; use the async resolver so the runtime is not blocked.
    tokio::net::lookup_host((host, port))
        .await
        .map_err(|e| Error::Transport(format!("无效的服务器地址: {e}")))?
        .next()
        .ok_or_else(|| Error::Transport(format!("无效的服务器地址: {host}:{port}")))
}

#[cfg(feature = "http")]
#[async_trait]
impl Transport for HttpTransport {
    /// Binds the configured address and serves `POST /mcp` until the server exits.
    ///
    /// # Errors
    ///
    /// Returns `Error::Transport` if no request handler has been set, if the
    /// host/port cannot be resolved, if the address cannot be bound (e.g. the port
    /// is already in use), or if the server fails while running.
    async fn start(&self) -> Result<()> {
        info!("正在启动 HTTP 传输,地址 {}:{}", self.host, self.port);

        // A handler is mandatory: without one there is nothing to dispatch requests to.
        let request_handler = self
            .request_handler
            .clone()
            .ok_or_else(|| Error::Transport("请求处理器未设置".to_string()))?;

        let app_state = Arc::new(HttpState { request_handler });

        let app = Router::new()
            .route("/mcp", post(handle_mcp_request))
            .with_state(app_state);

        // Resolve instead of string-parsing "host:port": the latter rejects hostnames
        // such as "localhost" and unbracketed IPv6 literals such as "::1".
        let addr = resolve_bind_addr(&self.host, self.port).await?;

        // `Server::bind` panics if binding fails; `try_bind` surfaces it as an error.
        let server = axum::Server::try_bind(&addr)
            .map_err(|e| Error::Transport(format!("无法绑定 HTTP 地址 {addr}: {e}")))?;

        info!("HTTP 服务器监听于 {}", addr);
        server
            .serve(app.into_make_service())
            .await
            .map_err(|e| Error::Transport(format!("HTTP 服务器错误: {e}")))?;

        Ok(())
    }

    /// Installs (or replaces) the handler used to dispatch incoming requests.
    fn set_request_handler(&mut self, handler: RequestHandler) {
        self.request_handler = Some(handler);
    }
}

/// 处理 MCP 请求
async fn handle_mcp_request(
    State(state): State<Arc<HttpState>>, Json(request): Json<Request>,
) -> impl IntoResponse {
    debug!("收到 HTTP 请求,工具: {}", request.tool);

    // 创建响应通道
    let (resp_tx, mut resp_rx) = mpsc::channel(1);

    // 使用请求处理器处理请求
    let tx = (state.request_handler)(request.clone());

    // 发送请求到处理器
    if let Err(e) = tx.send((request, resp_tx)).await {
        error!("发送请求失败: {}", e);
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": format!("处理请求失败: {}", e)
            })),
        );
    }

    // 等待响应
    match resp_rx.recv().await {
        Some(TransportMessage::Response(response)) => (StatusCode::OK, Json(json!(response))),
        Some(TransportMessage::Error(error)) => {
            let status = match error.code {
                400..=499 => StatusCode::BAD_REQUEST,
                500..=599 => StatusCode::INTERNAL_SERVER_ERROR,
                _ => StatusCode::INTERNAL_SERVER_ERROR,
            };
            (status, Json(json!(error)))
        }
        Some(TransportMessage::Notification(_)) => {
            (
                StatusCode::BAD_REQUEST,
                Json(json!({
                    "error": "收到通知而非响应"
                })),
            )
        }
        None => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "error": "未收到响应"
                })),
            )
        }
    }
}