Skip to main content

test_mcp_server/
test_mcp_server.rs

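//! Minimal MCP server for integration tests.
//!
//! Provides an `echo` tool plus counter tools (`increment`, `reset`,
//! `get_counter`) for testing connection, reconnection, and tool calls.
//!
//! Run with:
//! ```bash
//! cargo run --example test_mcp_server -p mcp-sse-proxy
//! ```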
9
10use rmcp::{
11    ErrorData, RoleServer, ServerHandler, ServiceExt,
12    model::{
13        CallToolRequestParam, CallToolResult, Content, Implementation, JsonObject, ListToolsResult,
14        PaginatedRequestParam, ProtocolVersion, ServerCapabilities, ServerInfo, Tool,
15    },
16    service::RequestContext,
17    transport::stdio,
18};
19use std::sync::Arc;
20use tokio::sync::Mutex;
21
22/// 测试用 MCP 服务器
23///
24/// 提供简单工具:
25/// - `echo`: 回显输入消息
26/// - `increment`: 递增计数器并返回当前值
27/// - `reset`: 重置计数器
28/// - `get_counter`: 获取当前计数器值
29#[derive(Clone)]
30pub struct TestMcpServer {
31    counter: Arc<Mutex<i32>>,
32}
33
34impl TestMcpServer {
35    /// 创建新的测试服务器实例
36    pub fn new() -> Self {
37        Self {
38            counter: Arc::new(Mutex::new(0)),
39        }
40    }
41
42    /// 创建一个简单的空 schema
43    fn empty_schema() -> Arc<JsonObject> {
44        let mut schema = JsonObject::new();
45        schema.insert("type".to_string(), serde_json::json!("object"));
46        schema.insert("properties".to_string(), serde_json::json!({}));
47        Arc::new(schema)
48    }
49
50    /// 创建 echo 工具的 schema
51    fn echo_schema() -> Arc<JsonObject> {
52        let mut schema = JsonObject::new();
53        schema.insert("type".to_string(), serde_json::json!("object"));
54        schema.insert(
55            "properties".to_string(),
56            serde_json::json!({
57                "message": {
58                    "type": "string",
59                    "description": "Message to echo back"
60                }
61            }),
62        );
63        schema.insert("required".to_string(), serde_json::json!(["message"]));
64        Arc::new(schema)
65    }
66
67    /// 获取工具定义列表
68    fn get_tools() -> Vec<Tool> {
69        vec![
70            Tool::new("echo", "Echo back the input message", Self::echo_schema()),
71            Tool::new(
72                "increment",
73                "Increment the counter by 1 and return new value",
74                Self::empty_schema(),
75            ),
76            Tool::new("reset", "Reset the counter to 0", Self::empty_schema()),
77            Tool::new(
78                "get_counter",
79                "Get current counter value without changing it",
80                Self::empty_schema(),
81            ),
82        ]
83    }
84
85    /// 处理 echo 工具调用
86    async fn handle_echo(&self, args: &serde_json::Value) -> CallToolResult {
87        let message = args
88            .get("message")
89            .and_then(|v| v.as_str())
90            .unwrap_or("(no message)");
91        CallToolResult::success(vec![Content::text(format!("Echo: {}", message))])
92    }
93
94    /// 处理 increment 工具调用
95    async fn handle_increment(&self) -> CallToolResult {
96        let mut counter = self.counter.lock().await;
97        *counter += 1;
98        CallToolResult::success(vec![Content::text(format!("Counter: {}", *counter))])
99    }
100
101    /// 处理 reset 工具调用
102    async fn handle_reset(&self) -> CallToolResult {
103        let mut counter = self.counter.lock().await;
104        *counter = 0;
105        CallToolResult::success(vec![Content::text("Counter reset to 0".to_string())])
106    }
107
108    /// 处理 get_counter 工具调用
109    async fn handle_get_counter(&self) -> CallToolResult {
110        let counter = self.counter.lock().await;
111        CallToolResult::success(vec![Content::text(format!("Counter: {}", *counter))])
112    }
113}
114
115impl Default for TestMcpServer {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl ServerHandler for TestMcpServer {
122    fn get_info(&self) -> ServerInfo {
123        ServerInfo {
124            protocol_version: ProtocolVersion::V_2024_11_05,
125            capabilities: ServerCapabilities::builder().enable_tools().build(),
126            server_info: Implementation {
127                name: "test-mcp-server".to_string(),
128                version: env!("CARGO_PKG_VERSION").to_string(),
129                title: Some("Test MCP Server".to_string()),
130                website_url: None,
131                icons: None,
132            },
133            instructions: Some(
134                "A minimal MCP server for integration testing. \
135                 Provides echo, increment, reset, and get_counter tools."
136                    .to_string(),
137            ),
138        }
139    }
140
141    async fn list_tools(
142        &self,
143        _request: Option<PaginatedRequestParam>,
144        _context: RequestContext<RoleServer>,
145    ) -> Result<ListToolsResult, ErrorData> {
146        Ok(ListToolsResult {
147            tools: Self::get_tools(),
148            next_cursor: None,
149        })
150    }
151
152    async fn call_tool(
153        &self,
154        request: CallToolRequestParam,
155        _context: RequestContext<RoleServer>,
156    ) -> Result<CallToolResult, ErrorData> {
157        let args = request
158            .arguments
159            .as_ref()
160            .map(|v| serde_json::Value::Object(v.clone()))
161            .unwrap_or(serde_json::Value::Object(Default::default()));
162
163        // 使用 &str 来进行匹配
164        let tool_name: &str = &request.name;
165        let result = match tool_name {
166            "echo" => self.handle_echo(&args).await,
167            "increment" => self.handle_increment().await,
168            "reset" => self.handle_reset().await,
169            "get_counter" => self.handle_get_counter().await,
170            _ => CallToolResult::error(vec![Content::text(format!(
171                "Unknown tool: {}",
172                request.name
173            ))]),
174        };
175
176        Ok(result)
177    }
178}
179
180/// 独立运行时作为 stdio MCP 服务器
181#[tokio::main]
182async fn main() -> anyhow::Result<()> {
183    // 初始化日志(输出到 stderr,避免干扰 stdio 通信)
184    tracing_subscriber::fmt()
185        .with_writer(std::io::stderr)
186        .with_env_filter(
187            tracing_subscriber::EnvFilter::from_default_env()
188                .add_directive(tracing::Level::INFO.into()),
189        )
190        .init();
191
192    tracing::info!("Test MCP Server starting...");
193
194    let server = TestMcpServer::new();
195    let transport = stdio();
196
197    tracing::info!("Serving on stdio...");
198    let running = server.serve(transport).await?;
199    running.waiting().await?;
200
201    tracing::info!("Test MCP Server stopped.");
202    Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the shared counter directly so tests verify state changes, not
    /// just that a call reported success.
    async fn counter_value(server: &TestMcpServer) -> i32 {
        *server.counter.lock().await
    }

    /// Serializes a tool result to JSON so it can be compared structurally
    /// without depending on the internal layout of rmcp's content types.
    fn to_json(result: &CallToolResult) -> serde_json::Value {
        serde_json::to_value(result).expect("CallToolResult should serialize to JSON")
    }

    /// JSON form of a successful result carrying a single text item.
    fn success_text(text: &str) -> serde_json::Value {
        to_json(&CallToolResult::success(vec![Content::text(text.to_string())]))
    }

    #[tokio::test]
    async fn test_counter_increment() {
        let server = TestMcpServer::new();

        // First increment
        let result = server.handle_increment().await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(to_json(&result), success_text("Counter: 1"));
        assert_eq!(counter_value(&server).await, 1);

        // Second increment
        let result = server.handle_increment().await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(to_json(&result), success_text("Counter: 2"));
        assert_eq!(counter_value(&server).await, 2);
    }

    #[tokio::test]
    async fn test_clones_share_counter() {
        let server = TestMcpServer::new();
        let clone = server.clone();

        let result = clone.handle_increment().await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(counter_value(&server).await, 1);
    }

    #[tokio::test]
    async fn test_echo() {
        let server = TestMcpServer::new();
        let args = serde_json::json!({"message": "hello"});
        let result = server.handle_echo(&args).await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(to_json(&result), success_text("Echo: hello"));
    }

    #[tokio::test]
    async fn test_echo_without_message() {
        let server = TestMcpServer::new();
        let result = server.handle_echo(&serde_json::json!({})).await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(to_json(&result), success_text("Echo: (no message)"));
    }

    #[tokio::test]
    async fn test_reset() {
        let server = TestMcpServer::new();

        // Increment a few times first
        server.handle_increment().await;
        server.handle_increment().await;
        assert_eq!(counter_value(&server).await, 2);

        // Reset
        let result = server.handle_reset().await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(to_json(&result), success_text("Counter reset to 0"));
        assert_eq!(counter_value(&server).await, 0);
    }

    #[tokio::test]
    async fn test_get_counter_does_not_modify() {
        let server = TestMcpServer::new();
        server.handle_increment().await;

        let result = server.handle_get_counter().await;
        assert!(!result.is_error.unwrap_or(false));
        assert_eq!(to_json(&result), success_text("Counter: 1"));
        assert_eq!(counter_value(&server).await, 1);
    }

    #[test]
    fn test_get_tools() {
        let tools = TestMcpServer::get_tools();
        assert_eq!(tools.len(), 4);
        assert!(tools.iter().any(|t| &*t.name == "echo"));
        assert!(tools.iter().any(|t| &*t.name == "increment"));
        assert!(tools.iter().any(|t| &*t.name == "reset"));
        assert!(tools.iter().any(|t| &*t.name == "get_counter"));

        // The echo schema must declare `message` as required.
        let echo = tools
            .iter()
            .find(|t| &*t.name == "echo")
            .expect("echo tool is defined");
        assert_eq!(
            echo.input_schema.get("required"),
            Some(&serde_json::json!(["message"]))
        );
    }
}