// rust_agent/mcp/server.rs

// MCP server: the `McpServer` trait and `SimpleMcpServer`, an HTTP JSON-RPC implementation built on axum.
2use anyhow::Error;
3use std::sync::{Arc, Mutex};
4use std::collections::HashMap;
5use crate::tools::Tool;
6use serde::{Deserialize, Serialize};
7use axum::{
8    extract::State,
9    response::Json,
10    routing::{get, post},
11    Router,
12};
13use tokio::net::TcpListener;
14use tower_http::cors::CorsLayer;
15use serde_json::Value;
16use log::{info, error};
17
18use crate::mcp::JSONRPCRequest;
19use crate::mcp::JSONRPCResponse;
20use crate::mcp::JSONRPCError;
21
22#[derive(Debug, Deserialize, Serialize)]
23struct CallToolParams {
24    name: String,
25    arguments: Option<std::collections::HashMap<String, serde_json::Value>>,
26}
27
28// MCP server implementation
29pub struct SimpleMcpServer {
30    address: String,
31    tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
32    is_running: Arc<Mutex<bool>>,
33    server_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
34}
35
36impl SimpleMcpServer {
37    pub fn new() -> Self {
38        Self {
39            address: "127.0.0.1:6000".to_string(),
40            tools: Arc::new(Mutex::new(HashMap::new())),
41            is_running: Arc::new(Mutex::new(false)),
42            server_handle: Arc::new(Mutex::new(None)),
43        }
44    }
45    
46    pub fn with_address(mut self, address: String) -> Self {
47        self.address = address;
48        self
49    }
50}
51
52// Simple test handler
53#[axum::debug_handler]
54async fn test_handler() -> &'static str {
55    "Hello, Rust-Agent!"
56}
57
58// Handle JSON-RPC request
59#[axum::debug_handler]
60async fn handle_jsonrpc_request(
61    State(state): State<Arc<SimpleMcpServerState>>,
62    Json(payload): Json<JSONRPCRequest>,
63) -> Json<JSONRPCResponse> {
64    let response = match payload.method.as_str() {
65        "tools/call" => {
66            // Handle tool call request
67            match handle_tool_call(state, payload.params).await {
68                Ok(result) => {
69                    JSONRPCResponse {
70                        jsonrpc: "2.0".to_string(),
71                        id: Some(payload.id.unwrap_or(Value::Null)),
72                        result: Some(result),
73                        error: None,
74                    }
75                }
76                Err(e) => {
77                    JSONRPCResponse {
78                        jsonrpc: "2.0".to_string(),
79                        id: Some(payload.id.unwrap_or(Value::Null)),
80                        result: None,
81                        error: Some(JSONRPCError {
82                            code: -32603,
83                            message: e.to_string(),
84                        }),
85                    }
86                }
87            }
88        }
89        "ping" => {
90            // Handle ping request
91            JSONRPCResponse {
92                jsonrpc: "2.0".to_string(),
93                id: Some(payload.id.unwrap_or(Value::Null)),
94                result: Some(Value::Object(serde_json::Map::new())),
95                error: None,
96            }
97        }
98        "tools/list" => {
99            // Handle tool list request
100            match handle_list_tools(state).await {
101                Ok(result) => {
102                    JSONRPCResponse {
103                        jsonrpc: "2.0".to_string(),
104                        id: Some(payload.id.unwrap_or(Value::Null)),
105                        result: Some(result),
106                        error: None,
107                    }
108                }
109                Err(e) => {
110                    JSONRPCResponse {
111                        jsonrpc: "2.0".to_string(),
112                        id: Some(payload.id.unwrap_or(Value::Null)),
113                        result: None,
114                        error: Some(JSONRPCError {
115                            code: -32603,
116                            message: e.to_string(),
117                        }),
118                    }
119                }
120            }
121        }
122        _ => {
123            // Unsupported method
124            JSONRPCResponse {
125                jsonrpc: "2.0".to_string(),
126                id: Some(payload.id.unwrap_or(Value::Null)),
127                result: None,
128                error: Some(JSONRPCError {
129                    code: -32601,
130                    message: "Method not found".to_string(),
131                }),
132            }
133        }
134    };
135    
136    Json(response)
137}
138
139async fn handle_list_tools(
140    state: Arc<SimpleMcpServerState>,
141) -> Result<serde_json::Value, Error> {
142    // Get all registered tools
143    let tools_map = state.tools.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
144    
145    // Convert to tool format required by MCP protocol
146    let mut tools_list = Vec::new();
147    for (_, tool) in tools_map.iter() {
148        let mcp_tool = serde_json::json!({
149            "name": tool.name(),
150            "description": tool.description(),
151            "inputSchema": {
152                "type": "object",
153                "properties": {},
154                "required": []
155            }
156        });
157        tools_list.push(mcp_tool);
158    }
159    
160    // Construct response
161    let result = serde_json::json!({
162        "tools": tools_list
163    });
164    
165    Ok(result)
166}
167
168async fn handle_tool_call(
169    state: Arc<SimpleMcpServerState>,
170    params: Option<serde_json::Value>,
171) -> Result<serde_json::Value, Error> {
172    // Parse parameters
173    let call_params: CallToolParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null))
174        .map_err(|e| Error::msg(format!("Invalid parameters: {}", e)))?;
175    
176    // Find tool and get its Arc reference
177    let tool = {
178        let tools = state.tools.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
179        tools.get(&call_params.name)
180            .ok_or_else(|| Error::msg(format!("Tool '{}' not found", call_params.name)))?
181            .clone()
182    };
183    
184    // Prepare tool input parameters
185    let input_str = if let Some(args) = call_params.arguments {
186        serde_json::to_string(&args)?
187    } else {
188        "{}".to_string()
189    };
190    
191    // Call tool (now can be called without holding the lock)
192    let result = tool.invoke(&input_str).await?;
193    Ok(serde_json::Value::String(result))
194}
195
196// Server state structure
197#[derive(Clone)]
198struct SimpleMcpServerState {
199    tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
200}
201
202// MCP server abstraction
203#[async_trait::async_trait]
204pub trait McpServer: Send + Sync {
205    // Start MCP server
206    async fn start(&self, address: &str) -> Result<(), Error>;
207    
208    // Register tool to MCP server
209    fn register_tool(&self, tool: Arc<dyn Tool>) -> Result<(), Error>;
210    
211    // Stop MCP server
212    async fn stop(&self) -> Result<(), Error>;
213}
214
215#[async_trait::async_trait]
216impl McpServer for SimpleMcpServer {
217    // Start MCP server
218    async fn start(&self, address: &str) -> Result<(), Error> {
219        info!("Starting MCP server on {}", address);
220        
221        // Create server state
222        let state = Arc::new(SimpleMcpServerState {
223            tools: self.tools.clone(),
224        });
225        
226        // Create routes
227        let app = Router::new()
228            .route("/rpc", post(handle_jsonrpc_request))
229            .route("/test", get(test_handler))
230            .with_state(state)
231            .layer(CorsLayer::permissive()); // Allow all CORS requests
232        
233        // Start server
234        let listener = TcpListener::bind(address).await
235            .map_err(|e| Error::msg(format!("Failed to bind to address {}: {}", address, e)))?;
236        
237        info!("MCP server listening on http://{}", address);
238        
239        // Run server in background task
240        let handle = tokio::spawn(async move {
241            if let Err(e) = axum::serve(listener, app.into_make_service()).await {
242                error!("Server error: {}", e);
243            }
244        });
245        
246        // Update server status
247        {
248            let mut is_running = self.is_running.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
249            *is_running = true;
250        }
251        
252        // Save server handle
253        {
254            let mut server_handle = self.server_handle.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
255            *server_handle = Some(handle);
256        }
257        
258        Ok(())
259    }
260    
261    // Register tool to MCP server
262    fn register_tool(&self, tool: Arc<dyn Tool>) -> Result<(), Error> {
263        let name = tool.name().to_string();
264        let mut tools = self.tools.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
265        tools.insert(name, tool);
266        Ok(())
267    }
268    
269    // Stop MCP server
270    async fn stop(&self) -> Result<(), Error> {
271        info!("Stopping MCP server");
272        
273        // Update server status
274        {
275            let mut is_running = self.is_running.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
276            *is_running = false;
277        }
278        
279        // Cancel server task
280        {
281            let mut server_handle = self.server_handle.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
282            if let Some(handle) = server_handle.take() {
283                handle.abort();
284            }
285        }
286        
287        Ok(())
288    }
289}