fastmcp 0.0.0

A Rust framework for building Model Context Protocol (MCP) services
Documentation
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{RwLock, mpsc};
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::protocol::Message;
use tracing::{debug, error, info, warn};

use super::{RequestHandler, Transport, TransportMessage};
use crate::error::{Error, Result};
use crate::protocol::Request;

/// Outbound queue handle for a single WebSocket client.
type ConnectionSender = mpsc::Sender<Message>;

/// Outbound senders of every active connection, shared between connection tasks.
type Connections = Arc<RwLock<Vec<ConnectionSender>>>;

/// WebSocket transport: listens on `host:port`, upgrades accepted TCP streams
/// to WebSocket and routes incoming requests to the installed request handler.
pub struct WebSocketTransport {
    /// Host to bind the listener to.
    host: String,
    /// TCP port to listen on.
    port: u16,
    /// Handler invoked for each parsed request; `None` until one is installed.
    request_handler: Option<RequestHandler>,
    /// Outbound senders of all live connections.
    connections: Connections,
}

// `RequestHandler` does not implement `Debug`, so `Debug` is written by hand.
impl fmt::Debug for WebSocketTransport {
    /// Formats the transport, replacing the handler with a presence flag and
    /// the connection list with a fixed placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_summary = format!("<handler: {}>", self.request_handler.is_some());

        let mut builder = f.debug_struct("WebSocketTransport");
        builder.field("host", &self.host);
        builder.field("port", &self.port);
        builder.field("request_handler", &handler_summary);
        builder.field("connections", &"<connections>");
        builder.finish()
    }
}

impl WebSocketTransport {
    /// 创建新的WebSocket传输
    pub fn new(host: String, port: u16) -> Self {
        Self {
            host,
            port,
            request_handler: None,
            connections: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// 处理单个WebSocket连接
    async fn handle_connection(
        websocket: tokio_tungstenite::WebSocketStream<TcpStream>, addr: SocketAddr,
        request_handler: RequestHandler, connections: Connections,
    ) {
        info!("新的WebSocket连接: {}", addr);

        // 分离WebSocket流为发送和接收部分
        let (mut ws_sender, mut ws_receiver) = websocket.split();

        // 为每个连接创建一个消息通道
        let (msg_tx, mut msg_rx) = mpsc::channel::<Message>(32);

        // 添加连接到活动连接列表
        {
            let mut connections = connections.write().await;
            connections.push(msg_tx.clone());
        }

        // 创建用于转发响应的任务
        let forward_task = tokio::spawn(async move {
            while let Some(msg) = msg_rx.recv().await {
                if let Err(e) = ws_sender.send(msg).await {
                    error!("发送WebSocket消息失败: {}", e);
                    break;
                }
            }
        });

        // 处理接收到的消息
        while let Some(result) = ws_receiver.next().await {
            match result {
                Ok(msg) => {
                    if msg.is_text() || msg.is_binary() {
                        let text = msg.to_text().unwrap_or_default();

                        // 尝试解析为请求
                        match serde_json::from_str::<Request>(text) {
                            Ok(request) => {
                                debug!("收到WebSocket请求: {}", request.tool);

                                // 创建响应channel
                                let (resp_tx, mut resp_rx) = mpsc::channel(10);

                                // 使用请求处理器
                                let tx = (request_handler)(request.clone());

                                // 处理请求并发送响应
                                let msg_tx_clone = msg_tx.clone();
                                tokio::spawn(async move {
                                    if let Err(e) = tx.send((request, resp_tx)).await {
                                        error!("发送请求到处理器失败: {}", e);
                                        return;
                                    }

                                    // 处理响应
                                    while let Some(message) = resp_rx.recv().await {
                                        match message {
                                            TransportMessage::Response(response) => {
                                                if let Ok(json) = serde_json::to_string(&response) {
                                                    if let Err(e) =
                                                        msg_tx_clone.send(Message::Text(json)).await
                                                    {
                                                        error!("发送响应失败: {}", e);
                                                    }
                                                }
                                            }
                                            TransportMessage::Error(error) => {
                                                if let Ok(json) = serde_json::to_string(&error) {
                                                    if let Err(e) =
                                                        msg_tx_clone.send(Message::Text(json)).await
                                                    {
                                                        error!("发送错误响应失败: {}", e);
                                                    }
                                                }
                                            }
                                            TransportMessage::Notification(notification) => {
                                                if let Ok(json) =
                                                    serde_json::to_string(&notification)
                                                {
                                                    if let Err(e) =
                                                        msg_tx_clone.send(Message::Text(json)).await
                                                    {
                                                        error!("发送通知失败: {}", e);
                                                    }
                                                }
                                            }
                                        }
                                    }
                                });
                            }
                            Err(e) => {
                                warn!("无法解析请求: {}", e);
                                let error_msg = json!({
                                    "error": "无效的请求格式",
                                    "details": e.to_string()
                                });

                                if let Ok(json) = serde_json::to_string(&error_msg) {
                                    if let Err(e) = msg_tx.send(Message::Text(json)).await {
                                        error!("发送错误消息失败: {}", e);
                                    }
                                }
                            }
                        }
                    } else if msg.is_close() {
                        break;
                    }
                }
                Err(e) => {
                    error!("WebSocket错误: {}", e);
                    break;
                }
            }
        }

        // 连接关闭,清理
        forward_task.abort();
        info!("WebSocket连接关闭: {}", addr);

        // 从活动连接列表中移除
        {
            let mut connections = connections.write().await;
            connections.retain(|tx| !tx.is_closed());
        }
    }
}

#[cfg(feature = "websocket")]
#[async_trait]
impl Transport for WebSocketTransport {
    /// Binds `host:port` and serves WebSocket clients, each on its own task.
    ///
    /// # Errors
    ///
    /// Returns `Error::Transport` if:
    /// - no request handler has been set;
    /// - the address cannot be bound; or
    /// - the listener reports an accept error other than a per-connection
    ///   aborted/reset/interrupted failure.
    ///
    /// Those per-connection failures are logged and skipped, and serving continues.
    async fn start(&self) -> Result<()> {
        info!("启动WebSocket传输,地址 {}:{}", self.host, self.port);

        // A handler is required before any connection can be served.
        let request_handler = self
            .request_handler
            .clone()
            .ok_or_else(|| Error::Transport("请求处理器未设置".to_string()))?;

        let addr = format!("{}:{}", self.host, self.port);
        let listener = TcpListener::bind(&addr)
            .await
            .map_err(|e| Error::Transport(format!("绑定WebSocket地址失败: {e}")))?;

        info!("WebSocket服务器监听于 {}", addr);

        loop {
            let (stream, peer_addr) = match listener.accept().await {
                Ok(accepted) => accepted,
                // These kinds are typically tied to a single client connection
                // (e.g. the peer reset before accept completed), so keep serving.
                Err(e)
                    if matches!(
                        e.kind(),
                        std::io::ErrorKind::ConnectionAborted
                            | std::io::ErrorKind::ConnectionReset
                            | std::io::ErrorKind::Interrupted
                    ) =>
                {
                    warn!("接受WebSocket连接失败,继续监听: {}", e);
                    continue;
                }
                // Anything else is treated as fatal for the listener and reported
                // to the caller rather than ending with a misleading `Ok(())`.
                Err(e) => {
                    error!("接受WebSocket连接失败: {}", e);
                    return Err(Error::Transport(format!("接受WebSocket连接失败: {e}")));
                }
            };

            let request_handler = request_handler.clone();
            let connections = self.connections.clone();

            // Perform the WebSocket handshake and serve the connection off the accept loop.
            tokio::spawn(async move {
                match accept_async(stream).await {
                    Ok(ws_stream) => {
                        Self::handle_connection(ws_stream, peer_addr, request_handler, connections)
                            .await;
                    }
                    Err(e) => {
                        error!("WebSocket握手失败: {}", e);
                    }
                }
            });
        }
    }

    /// Installs the handler that will receive every parsed request.
    fn set_request_handler(&mut self, handler: RequestHandler) {
        self.request_handler = Some(handler);
    }
}