Skip to main content

crates_docs/server/
handler.rs

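//! MCP request handler implementations.
//!
//! Provides the request-handling logic for the MCP protocol, including tool listing, tool calls, resource listing, and more.
//!
//! # Main types
//!
//! - `CratesDocsHandler`: standard MCP handler
//! - `CratesDocsHandlerCore`: core handler (offers finer-grained control)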
9
10use crate::server::CratesDocsServer;
11use crate::tools::ToolRegistry;
12use async_trait::async_trait;
13use rust_mcp_sdk::{
14    mcp_server::{ServerHandler, ServerHandlerCore},
15    schema::{
16        CallToolError, CallToolRequestParams, CallToolResult, GetPromptRequestParams,
17        GetPromptResult, ListPromptsResult, ListResourcesResult, ListToolsResult,
18        NotificationFromClient, PaginatedRequestParams, ReadResourceRequestParams,
19        ReadResourceResult, RequestFromClient, ResultFromServer, RpcError,
20    },
21    McpServer,
22};
23use std::sync::Arc;
24use tracing::{info_span, Instrument};
25use uuid::Uuid;
26
27/// MCP 服务器处理器
28///
29/// 实现标准 MCP 协议处理器接口,处理客户端请求。
30///
31/// # 字段
32///
33/// - `server`: 服务器实例的 Arc 引用
34pub struct CratesDocsHandler {
35    server: Arc<CratesDocsServer>,
36}
37
38impl CratesDocsHandler {
39    /// 创建新的处理器
40    ///
41    /// # 参数
42    ///
43    /// * `server` - 服务器实例
44    ///
45    /// # 示例
46    ///
47    /// ```rust,no_run
48    /// use std::sync::Arc;
49    /// use crates_docs::server::{CratesDocsServer, CratesDocsHandler};
50    /// use crates_docs::AppConfig;
51    ///
52    /// let config = AppConfig::default();
53    /// let server = Arc::new(CratesDocsServer::new(config).unwrap());
54    /// let handler = CratesDocsHandler::new(server);
55    /// ```
56    #[must_use]
57    pub fn new(server: Arc<CratesDocsServer>) -> Self {
58        Self { server }
59    }
60
61    /// 获取工具注册表
62    fn tool_registry(&self) -> &ToolRegistry {
63        self.server.tool_registry()
64    }
65}
66
67#[async_trait]
68impl ServerHandler for CratesDocsHandler {
69    /// Handle list tools request
70    async fn handle_list_tools_request(
71        &self,
72        _request: Option<PaginatedRequestParams>,
73        _runtime: std::sync::Arc<dyn McpServer>,
74    ) -> std::result::Result<ListToolsResult, RpcError> {
75        let trace_id = Uuid::new_v4().to_string();
76        let span = info_span!(
77            "list_tools",
78            trace_id = %trace_id,
79        );
80
81        async {
82            tracing::debug!("Listing available tools");
83            let tools = self.tool_registry().get_tools();
84            tracing::debug!("Found {} tools", tools.len());
85
86            Ok(ListToolsResult {
87                tools,
88                meta: None,
89                next_cursor: None,
90            })
91        }
92        .instrument(span)
93        .await
94    }
95
96    /// Handle call tool request
97    async fn handle_call_tool_request(
98        &self,
99        params: CallToolRequestParams,
100        _runtime: std::sync::Arc<dyn McpServer>,
101    ) -> std::result::Result<CallToolResult, CallToolError> {
102        let trace_id = Uuid::new_v4().to_string();
103        let tool_name = params.name.clone();
104        let span = info_span!(
105            "call_tool",
106            trace_id = %trace_id,
107            tool = %tool_name,
108        );
109
110        async {
111            tracing::info!("Executing tool: {}", tool_name);
112            let start = std::time::Instant::now();
113
114            let result = self
115                .tool_registry()
116                .execute_tool(
117                    &tool_name,
118                    params
119                        .arguments
120                        .map_or_else(|| serde_json::Value::Null, serde_json::Value::Object),
121                )
122                .await;
123
124            let duration = start.elapsed();
125            match &result {
126                Ok(_) => {
127                    tracing::info!("Tool {} executed successfully in {:?}", tool_name, duration);
128                }
129                Err(e) => {
130                    tracing::error!(
131                        "Tool {} execution failed after {:?}: {:?}",
132                        tool_name,
133                        duration,
134                        e
135                    );
136                }
137            }
138
139            result
140        }
141        .instrument(span)
142        .await
143    }
144
145    /// Handle list resources request
146    async fn handle_list_resources_request(
147        &self,
148        _request: Option<PaginatedRequestParams>,
149        _runtime: std::sync::Arc<dyn McpServer>,
150    ) -> std::result::Result<ListResourcesResult, RpcError> {
151        // Resources are not currently provided
152        Ok(ListResourcesResult {
153            resources: vec![],
154            meta: None,
155            next_cursor: None,
156        })
157    }
158
159    /// Handle read resource request
160    async fn handle_read_resource_request(
161        &self,
162        _params: ReadResourceRequestParams,
163        _runtime: std::sync::Arc<dyn McpServer>,
164    ) -> std::result::Result<ReadResourceResult, RpcError> {
165        // Resources are not currently provided
166        Err(RpcError::invalid_request().with_message("Resource not found".to_string()))
167    }
168
169    /// Handle list prompts request
170    async fn handle_list_prompts_request(
171        &self,
172        _request: Option<PaginatedRequestParams>,
173        _runtime: std::sync::Arc<dyn McpServer>,
174    ) -> std::result::Result<ListPromptsResult, RpcError> {
175        // Prompts are not currently provided
176        Ok(ListPromptsResult {
177            prompts: vec![],
178            meta: None,
179            next_cursor: None,
180        })
181    }
182
183    /// Handle get prompt request
184    async fn handle_get_prompt_request(
185        &self,
186        _params: GetPromptRequestParams,
187        _runtime: std::sync::Arc<dyn McpServer>,
188    ) -> std::result::Result<GetPromptResult, RpcError> {
189        // Prompts are not currently provided
190        Err(RpcError::invalid_request().with_message("Prompt not found".to_string()))
191    }
192}
193
194/// Core handler implementation (provides more control)
195pub struct CratesDocsHandlerCore {
196    server: Arc<CratesDocsServer>,
197}
198
199impl CratesDocsHandlerCore {
200    /// Create a new core handler
201    #[must_use]
202    pub fn new(server: Arc<CratesDocsServer>) -> Self {
203        Self { server }
204    }
205
206    async fn execute_tool_request(&self, params: CallToolRequestParams) -> ResultFromServer {
207        let trace_id = Uuid::new_v4().to_string();
208        let tool_name = params.name.clone();
209        let span = info_span!(
210            "execute_tool_core",
211            trace_id = %trace_id,
212            tool = %tool_name,
213        );
214
215        async {
216            tracing::info!("Executing tool request: {}", tool_name);
217            let start = std::time::Instant::now();
218
219            let result = self
220                .server
221                .tool_registry()
222                .execute_tool(
223                    &tool_name,
224                    params
225                        .arguments
226                        .map_or_else(|| serde_json::Value::Null, serde_json::Value::Object),
227                )
228                .await;
229
230            let duration = start.elapsed();
231            match &result {
232                Ok(_) => {
233                    tracing::info!("Tool {} executed successfully in {:?}", tool_name, duration);
234                }
235                Err(e) => {
236                    tracing::error!(
237                        "Tool {} execution failed after {:?}: {:?}",
238                        tool_name,
239                        duration,
240                        e
241                    );
242                }
243            }
244
245            result.unwrap_or_else(CallToolResult::from).into()
246        }
247        .instrument(span)
248        .await
249    }
250}
251
252#[async_trait]
253impl ServerHandlerCore for CratesDocsHandlerCore {
254    /// Handle request
255    async fn handle_request(
256        &self,
257        request: RequestFromClient,
258        _runtime: std::sync::Arc<dyn McpServer>,
259    ) -> std::result::Result<ResultFromServer, RpcError> {
260        match request {
261            RequestFromClient::ListToolsRequest(_params) => {
262                let tools = self.server.tool_registry().get_tools();
263                Ok(ListToolsResult {
264                    tools,
265                    meta: None,
266                    next_cursor: None,
267                }
268                .into())
269            }
270            RequestFromClient::CallToolRequest(params) => {
271                Ok(self.execute_tool_request(params).await)
272            }
273            RequestFromClient::ListResourcesRequest(_params) => Ok(ListResourcesResult {
274                resources: vec![],
275                meta: None,
276                next_cursor: None,
277            }
278            .into()),
279            RequestFromClient::ReadResourceRequest(_params) => {
280                Err(RpcError::invalid_request().with_message("Resource not found".to_string()))
281            }
282            RequestFromClient::ListPromptsRequest(_params) => Ok(ListPromptsResult {
283                prompts: vec![],
284                meta: None,
285                next_cursor: None,
286            }
287            .into()),
288            RequestFromClient::GetPromptRequest(_params) => {
289                Err(RpcError::invalid_request().with_message("Prompt not found".to_string()))
290            }
291            RequestFromClient::InitializeRequest(_params) => {
292                // Use default initialization handling
293                Err(RpcError::method_not_found()
294                    .with_message("Initialize request should be handled by runtime".to_string()))
295            }
296            _ => {
297                // Other requests use default handling
298                Err(RpcError::method_not_found()
299                    .with_message("Unimplemented request type".to_string()))
300            }
301        }
302    }
303
304    /// Handle notification
305    async fn handle_notification(
306        &self,
307        _notification: NotificationFromClient,
308        _runtime: std::sync::Arc<dyn McpServer>,
309    ) -> std::result::Result<(), RpcError> {
310        // Notifications are not currently handled
311        Ok(())
312    }
313
314    /// Handle error
315    async fn handle_error(
316        &self,
317        _error: &RpcError,
318        _runtime: std::sync::Arc<dyn McpServer>,
319    ) -> std::result::Result<(), RpcError> {
320        // Log error but don't interrupt
321        tracing::error!("MCP error: {:?}", _error);
322        Ok(())
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::CratesDocsHandlerCore;
    use crate::server::CratesDocsServer;
    use rust_mcp_sdk::schema::{CallToolRequestParams, CallToolResult, ContentBlock};
    use std::sync::Arc;

    /// A tool that rejects its arguments must come back as an `is_error`
    /// tool result that names the tool, not as an "Unknown tool" failure.
    #[tokio::test]
    async fn test_execute_tool_request_preserves_tool_errors() {
        let config = crate::AppConfig::default();
        let server = CratesDocsServer::new(config).unwrap();
        let handler = CratesDocsHandlerCore::new(Arc::new(server));

        // `verbose` gets a string value; the tool presumably expects a
        // different type here, which forces a parameter-parsing failure.
        let mut arguments = serde_json::Map::new();
        arguments.insert(
            "verbose".to_string(),
            serde_json::Value::String("bad".to_string()),
        );
        let params = CallToolRequestParams {
            arguments: Some(arguments),
            meta: None,
            name: "health_check".to_string(),
            task: None,
        };

        let raw = handler.execute_tool_request(params).await;
        let tool_result = CallToolResult::try_from(raw).unwrap();
        assert_eq!(tool_result.is_error, Some(true));

        let text = match tool_result.content.first() {
            Some(ContentBlock::TextContent(text)) => text,
            _ => panic!("expected first content block to be text"),
        };

        assert!(text.text.contains("health_check"));
        assert!(text.text.contains("Parameter parsing failed"));
        assert!(!text.text.contains("Unknown tool"));
    }
}
360}