1use anyhow::Error;
3use std::sync::{Arc, Mutex};
4use std::collections::HashMap;
5use crate::tools::Tool;
6use serde::{Deserialize, Serialize};
7use axum::{
8 extract::State,
9 response::Json,
10 routing::{get, post},
11 Router,
12};
13use tokio::net::TcpListener;
14use tower_http::cors::CorsLayer;
15use serde_json::Value;
16use log::{info, error};
17
18use crate::mcp::JSONRPCRequest;
19use crate::mcp::JSONRPCResponse;
20use crate::mcp::JSONRPCError;
21
22#[derive(Debug, Deserialize, Serialize)]
23struct CallToolParams {
24 name: String,
25 arguments: Option<std::collections::HashMap<String, serde_json::Value>>,
26}
27
28pub struct SimpleMcpServer {
30 address: String,
31 tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
32 is_running: Arc<Mutex<bool>>,
33 server_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
34}
35
36impl SimpleMcpServer {
37 pub fn new() -> Self {
38 Self {
39 address: "127.0.0.1:6000".to_string(),
40 tools: Arc::new(Mutex::new(HashMap::new())),
41 is_running: Arc::new(Mutex::new(false)),
42 server_handle: Arc::new(Mutex::new(None)),
43 }
44 }
45
46 pub fn with_address(mut self, address: String) -> Self {
47 self.address = address;
48 self
49 }
50}
51
52#[axum::debug_handler]
54async fn test_handler() -> &'static str {
55 "Hello, Rust-Agent!"
56}
57
58#[axum::debug_handler]
60async fn handle_jsonrpc_request(
61 State(state): State<Arc<SimpleMcpServerState>>,
62 Json(payload): Json<JSONRPCRequest>,
63) -> Json<JSONRPCResponse> {
64 let response = match payload.method.as_str() {
65 "tools/call" => {
66 match handle_tool_call(state, payload.params).await {
68 Ok(result) => {
69 JSONRPCResponse {
70 jsonrpc: "2.0".to_string(),
71 id: Some(payload.id.unwrap_or(Value::Null)),
72 result: Some(result),
73 error: None,
74 }
75 }
76 Err(e) => {
77 JSONRPCResponse {
78 jsonrpc: "2.0".to_string(),
79 id: Some(payload.id.unwrap_or(Value::Null)),
80 result: None,
81 error: Some(JSONRPCError {
82 code: -32603,
83 message: e.to_string(),
84 }),
85 }
86 }
87 }
88 }
89 "ping" => {
90 JSONRPCResponse {
92 jsonrpc: "2.0".to_string(),
93 id: Some(payload.id.unwrap_or(Value::Null)),
94 result: Some(Value::Object(serde_json::Map::new())),
95 error: None,
96 }
97 }
98 "tools/list" => {
99 match handle_list_tools(state).await {
101 Ok(result) => {
102 JSONRPCResponse {
103 jsonrpc: "2.0".to_string(),
104 id: Some(payload.id.unwrap_or(Value::Null)),
105 result: Some(result),
106 error: None,
107 }
108 }
109 Err(e) => {
110 JSONRPCResponse {
111 jsonrpc: "2.0".to_string(),
112 id: Some(payload.id.unwrap_or(Value::Null)),
113 result: None,
114 error: Some(JSONRPCError {
115 code: -32603,
116 message: e.to_string(),
117 }),
118 }
119 }
120 }
121 }
122 _ => {
123 JSONRPCResponse {
125 jsonrpc: "2.0".to_string(),
126 id: Some(payload.id.unwrap_or(Value::Null)),
127 result: None,
128 error: Some(JSONRPCError {
129 code: -32601,
130 message: "Method not found".to_string(),
131 }),
132 }
133 }
134 };
135
136 Json(response)
137}
138
139async fn handle_list_tools(
140 state: Arc<SimpleMcpServerState>,
141) -> Result<serde_json::Value, Error> {
142 let tools_map = state.tools.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
144
145 let mut tools_list = Vec::new();
147 for (_, tool) in tools_map.iter() {
148 let mcp_tool = serde_json::json!({
149 "name": tool.name(),
150 "description": tool.description(),
151 "inputSchema": {
152 "type": "object",
153 "properties": {},
154 "required": []
155 }
156 });
157 tools_list.push(mcp_tool);
158 }
159
160 let result = serde_json::json!({
162 "tools": tools_list
163 });
164
165 Ok(result)
166}
167
168async fn handle_tool_call(
169 state: Arc<SimpleMcpServerState>,
170 params: Option<serde_json::Value>,
171) -> Result<serde_json::Value, Error> {
172 let call_params: CallToolParams = serde_json::from_value(params.unwrap_or(serde_json::Value::Null))
174 .map_err(|e| Error::msg(format!("Invalid parameters: {}", e)))?;
175
176 let tool = {
178 let tools = state.tools.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
179 tools.get(&call_params.name)
180 .ok_or_else(|| Error::msg(format!("Tool '{}' not found", call_params.name)))?
181 .clone()
182 };
183
184 let input_str = if let Some(args) = call_params.arguments {
186 serde_json::to_string(&args)?
187 } else {
188 "{}".to_string()
189 };
190
191 let result = tool.invoke(&input_str).await?;
193 Ok(serde_json::Value::String(result))
194}
195
/// State handed to the axum route handlers via `Router::with_state`.
#[derive(Clone)]
struct SimpleMcpServerState {
    /// Tool registry keyed by tool name. `start` fills this with a clone of the server's
    /// own `tools` Arc, so both share one map and tools registered after start are visible.
    tools: Arc<Mutex<HashMap<String, Arc<dyn Tool>>>>,
}
201
/// Lifecycle and tool-registration interface for an MCP server.
///
/// Async methods are declared through `async_trait`; implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait McpServer: Send + Sync {
    /// Starts serving on `address` (a `host:port` string such as `"127.0.0.1:6000"`).
    async fn start(&self, address: &str) -> Result<(), Error>;

    /// Registers `tool` with the server. In the visible implementation, tools are keyed
    /// by name, and registering the same name again replaces the earlier tool.
    fn register_tool(&self, tool: Arc<dyn Tool>) -> Result<(), Error>;

    /// Stops the server started by `start`, if any.
    async fn stop(&self) -> Result<(), Error>;
}
214
215#[async_trait::async_trait]
216impl McpServer for SimpleMcpServer {
217 async fn start(&self, address: &str) -> Result<(), Error> {
219 info!("Starting MCP server on {}", address);
220
221 let state = Arc::new(SimpleMcpServerState {
223 tools: self.tools.clone(),
224 });
225
226 let app = Router::new()
228 .route("/rpc", post(handle_jsonrpc_request))
229 .route("/test", get(test_handler))
230 .with_state(state)
231 .layer(CorsLayer::permissive()); let listener = TcpListener::bind(address).await
235 .map_err(|e| Error::msg(format!("Failed to bind to address {}: {}", address, e)))?;
236
237 info!("MCP server listening on http://{}", address);
238
239 let handle = tokio::spawn(async move {
241 if let Err(e) = axum::serve(listener, app.into_make_service()).await {
242 error!("Server error: {}", e);
243 }
244 });
245
246 {
248 let mut is_running = self.is_running.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
249 *is_running = true;
250 }
251
252 {
254 let mut server_handle = self.server_handle.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
255 *server_handle = Some(handle);
256 }
257
258 Ok(())
259 }
260
261 fn register_tool(&self, tool: Arc<dyn Tool>) -> Result<(), Error> {
263 let name = tool.name().to_string();
264 let mut tools = self.tools.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
265 tools.insert(name, tool);
266 Ok(())
267 }
268
269 async fn stop(&self) -> Result<(), Error> {
271 info!("Stopping MCP server");
272
273 {
275 let mut is_running = self.is_running.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
276 *is_running = false;
277 }
278
279 {
281 let mut server_handle = self.server_handle.lock().map_err(|e| Error::msg(format!("Failed to acquire lock: {}", e)))?;
282 if let Some(handle) = server_handle.take() {
283 handle.abort();
284 }
285 }
286
287 Ok(())
288 }
289}