1use rmcp::{
27 model::*,
28 service::{RequestContext, RoleServer},
29 transport::stdio,
30 ServerHandler, ServiceExt,
31};
32use serde::{Deserialize, Serialize};
33use serde_json::json;
34
35use std::sync::Arc;
36use tokio::sync::Mutex;
37
38#[derive(Debug, Deserialize, Serialize)]
40pub struct CalculatorArgs {
41 pub operation: String,
43 pub a: f64,
45 pub b: f64,
47}
48
49#[derive(Debug, Deserialize, Serialize)]
51pub struct WebSearchArgs {
52 pub query: String,
54 #[serde(default = "default_max_results")]
56 pub max_results: usize,
57}
58
/// Serde default for `WebSearchArgs::max_results`, used when the caller omits the field.
fn default_max_results() -> usize {
    const DEFAULT_RESULT_COUNT: usize = 5;
    DEFAULT_RESULT_COUNT
}
62
63#[derive(Clone)]
65pub struct McpServer {
66 operation_count: Arc<Mutex<u64>>,
68}
69
70impl McpServer {
71 pub fn new() -> Self {
73 Self {
74 operation_count: Arc::new(Mutex::new(0)),
75 }
76 }
77
78 fn get_tools() -> Vec<Tool> {
80 vec![
81 Tool {
82 name: "calculator".into(),
83 description: Some(
84 "Perform basic arithmetic operations (add, subtract, multiply, divide)".into(),
85 ),
86 input_schema: serde_json::from_value(json!({
87 "type": "object",
88 "properties": {
89 "operation": {
90 "type": "string",
91 "enum": ["add", "subtract", "multiply", "divide"],
92 "description": "The arithmetic operation to perform"
93 },
94 "a": {
95 "type": "number",
96 "description": "First operand"
97 },
98 "b": {
99 "type": "number",
100 "description": "Second operand"
101 }
102 },
103 "required": ["operation", "a", "b"]
104 }))
105 .unwrap_or_default(),
106 annotations: None,
107 icons: None,
108 meta: None,
109 output_schema: None,
110 title: Some("Calculator".into()),
111 },
112 Tool {
113 name: "web_search".into(),
114 description: Some(
115 "Search the web for information using DuckDuckGo. Returns a list of search results with titles, snippets, and URLs.".into(),
116 ),
117 input_schema: serde_json::from_value(json!({
118 "type": "object",
119 "properties": {
120 "query": {
121 "type": "string",
122 "description": "The search query"
123 },
124 "max_results": {
125 "type": "integer",
126 "description": "Maximum number of results (default: 5)",
127 "default": 5
128 }
129 },
130 "required": ["query"]
131 }))
132 .unwrap_or_default(),
133 annotations: None,
134 icons: None,
135 meta: None,
136 output_schema: None,
137 title: Some("Web Search".into()),
138 },
139 Tool {
140 name: "server_stats".into(),
141 description: Some(
142 "Get statistics about the MCP server including operation count".into(),
143 ),
144 input_schema: serde_json::from_value(json!({
145 "type": "object",
146 "properties": {}
147 }))
148 .unwrap_or_default(),
149 annotations: None,
150 icons: None,
151 meta: None,
152 output_schema: None,
153 title: Some("Server Stats".into()),
154 },
155 Tool {
156 name: "echo".into(),
157 description: Some("Echo back the input message (useful for testing)".into()),
158 input_schema: serde_json::from_value(json!({
159 "type": "object",
160 "properties": {
161 "message": {
162 "type": "string",
163 "description": "The message to echo back"
164 }
165 },
166 "required": ["message"]
167 }))
168 .unwrap_or_default(),
169 annotations: None,
170 icons: None,
171 meta: None,
172 output_schema: None,
173 title: Some("Echo".into()),
174 },
175 ]
176 }
177
178 async fn execute_calculator(&self, args: CalculatorArgs) -> CallToolResult {
180 let mut count = self.operation_count.lock().await;
181 *count += 1;
182
183 let result = match args.operation.as_str() {
184 "add" => args.a + args.b,
185 "subtract" => args.a - args.b,
186 "multiply" => args.a * args.b,
187 "divide" => {
188 if args.b == 0.0 {
189 return CallToolResult::error(vec![Content::text("Error: Division by zero")]);
190 }
191 args.a / args.b
192 }
193 op => {
194 return CallToolResult::error(vec![Content::text(format!(
195 "Error: Unknown operation '{}'. Supported: add, subtract, multiply, divide",
196 op
197 ))]);
198 }
199 };
200
201 let response = json!({
202 "operation": args.operation,
203 "a": args.a,
204 "b": args.b,
205 "result": result
206 });
207
208 CallToolResult::success(vec![Content::text(
209 serde_json::to_string_pretty(&response).unwrap_or_else(|_| result.to_string()),
210 )])
211 }
212
213 async fn execute_web_search(&self, args: WebSearchArgs) -> CallToolResult {
215 let mut count = self.operation_count.lock().await;
216 *count += 1;
217
218 let search_args = daedra::types::SearchArgs {
220 query: args.query.clone(),
221 options: Some(daedra::types::SearchOptions {
222 num_results: args.max_results,
223 ..Default::default()
224 }),
225 };
226
227 match daedra::tools::search::perform_search(&search_args).await {
228 Ok(results) => {
229 let json_results: Vec<serde_json::Value> = results
230 .data
231 .into_iter()
232 .map(|result| {
233 json!({
234 "title": result.title,
235 "url": result.url,
236 "snippet": result.description
237 })
238 })
239 .collect();
240
241 let response = json!({
242 "query": args.query,
243 "results": json_results,
244 "count": json_results.len()
245 });
246
247 CallToolResult::success(vec![Content::text(
248 serde_json::to_string_pretty(&response)
249 .unwrap_or_else(|_| "Search completed".to_string()),
250 )])
251 }
252 Err(e) => CallToolResult::error(vec![Content::text(format!("Search failed: {}", e))]),
253 }
254 }
255
256 async fn execute_server_stats(&self) -> CallToolResult {
258 let count = self.operation_count.lock().await;
259
260 let response = json!({
261 "server": "ARES MCP Server",
262 "version": env!("CARGO_PKG_VERSION"),
263 "operation_count": *count,
264 "available_tools": ["calculator", "web_search", "server_stats", "echo"]
265 });
266
267 CallToolResult::success(vec![Content::text(
268 serde_json::to_string_pretty(&response).unwrap_or_else(|_| "Stats unavailable".into()),
269 )])
270 }
271
272 async fn execute_echo(&self, message: String) -> CallToolResult {
274 let mut count = self.operation_count.lock().await;
275 *count += 1;
276
277 CallToolResult::success(vec![Content::text(message)])
278 }
279
280 async fn execute_tool(
282 &self,
283 name: &str,
284 arguments: Option<serde_json::Map<String, serde_json::Value>>,
285 ) -> CallToolResult {
286 let args = arguments.unwrap_or_default();
287 let args_value = serde_json::Value::Object(args);
288
289 match name {
290 "calculator" => match serde_json::from_value::<CalculatorArgs>(args_value) {
291 Ok(calc_args) => self.execute_calculator(calc_args).await,
292 Err(e) => CallToolResult::error(vec![Content::text(format!(
293 "Invalid calculator arguments: {}",
294 e
295 ))]),
296 },
297 "web_search" => match serde_json::from_value::<WebSearchArgs>(args_value) {
298 Ok(search_args) => self.execute_web_search(search_args).await,
299 Err(e) => CallToolResult::error(vec![Content::text(format!(
300 "Invalid search arguments: {}",
301 e
302 ))]),
303 },
304 "server_stats" => self.execute_server_stats().await,
305 "echo" => {
306 let message = args_value
307 .get("message")
308 .and_then(|v| v.as_str())
309 .unwrap_or("")
310 .to_string();
311 self.execute_echo(message).await
312 }
313 _ => CallToolResult::error(vec![Content::text(format!("Unknown tool: {}", name))]),
314 }
315 }
316}
317
318impl Default for McpServer {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
324impl ServerHandler for McpServer {
326 fn get_info(&self) -> ServerInfo {
327 ServerInfo {
328 protocol_version: ProtocolVersion::LATEST,
329 capabilities: ServerCapabilities::builder().enable_tools().build(),
330 server_info: Implementation::from_build_env(),
331 instructions: Some(
332 "A.R.E.S MCP Server - Provides calculator, web search, and utility tools".into(),
333 ),
334 }
335 }
336
337 async fn list_tools(
338 &self,
339 _request: Option<PaginatedRequestParam>,
340 _context: RequestContext<RoleServer>,
341 ) -> Result<ListToolsResult, rmcp::ErrorData> {
342 Ok(ListToolsResult {
343 tools: Self::get_tools(),
344 next_cursor: None,
345 meta: None,
346 })
347 }
348
349 async fn call_tool(
350 &self,
351 request: CallToolRequestParam,
352 _context: RequestContext<RoleServer>,
353 ) -> Result<CallToolResult, rmcp::ErrorData> {
354 Ok(self.execute_tool(&request.name, request.arguments).await)
355 }
356}
357
358impl McpServer {
359 pub async fn start() -> crate::types::Result<()> {
367 tracing::info!("Starting A.R.E.S MCP Server v{}", env!("CARGO_PKG_VERSION"));
368
369 let server = McpServer::new();
370
371 let service = server
373 .serve(stdio())
374 .await
375 .map_err(|e| crate::types::AppError::External(format!("MCP server error: {}", e)))?;
376
377 tracing::info!("MCP server started successfully");
378
379 service
381 .waiting()
382 .await
383 .map_err(|e| crate::types::AppError::External(format!("MCP server error: {}", e)))?;
384
385 tracing::info!("MCP server shut down");
386 Ok(())
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of the first content item, failing the test if there is no content
    /// or it is not text. (An `if let` here would let a wrong content type pass silently.)
    fn first_text(result: &CallToolResult) -> &str {
        let content = result
            .content
            .first()
            .expect("tool result should carry at least one content item");
        match &content.raw {
            RawContent::Text(text) => &text.text,
            _ => panic!("expected text content in tool result"),
        }
    }

    /// Asserts the result is not flagged as an error.
    fn assert_success(result: &CallToolResult) {
        assert_ne!(
            result.is_error,
            Some(true),
            "expected success, got error: {}",
            first_text(result)
        );
    }

    /// Asserts the result is flagged as an error.
    fn assert_error(result: &CallToolResult) {
        assert_eq!(result.is_error, Some(true), "expected an error result");
    }

    #[test]
    fn test_calculator_args_parsing() {
        let json = r#"{"operation": "add", "a": 5.0, "b": 3.0}"#;
        let args: CalculatorArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args.operation, "add");
        assert_eq!(args.a, 5.0);
        assert_eq!(args.b, 3.0);
    }

    #[test]
    fn test_web_search_args_default() {
        let json = r#"{"query": "test query"}"#;
        let args: WebSearchArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args.query, "test query");
        assert_eq!(args.max_results, 5);
    }

    #[test]
    fn test_web_search_args_with_max_results() {
        let json = r#"{"query": "test query", "max_results": 10}"#;
        let args: WebSearchArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args.query, "test query");
        assert_eq!(args.max_results, 10);
    }

    #[test]
    fn test_mcp_server_creation() {
        let _server = McpServer::new();
    }

    #[test]
    fn test_mcp_server_default() {
        let _server = McpServer::default();
    }

    #[test]
    fn test_get_tools() {
        let tools = McpServer::get_tools();
        let tool_names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
        assert_eq!(tool_names, ["calculator", "web_search", "server_stats", "echo"]);
        assert!(tools
            .iter()
            .all(|t| t.description.is_some() && t.title.is_some()));
    }

    #[tokio::test]
    async fn test_calculator_add() {
        let server = McpServer::new();
        let args = CalculatorArgs {
            operation: "add".to_string(),
            a: 5.0,
            b: 3.0,
        };
        let result = server.execute_calculator(args).await;
        assert_success(&result);
        assert!(first_text(&result).contains('8'));
    }

    #[tokio::test]
    async fn test_calculator_divide_by_zero() {
        let server = McpServer::new();
        let args = CalculatorArgs {
            operation: "divide".to_string(),
            a: 5.0,
            b: 0.0,
        };
        let result = server.execute_calculator(args).await;
        assert_error(&result);
        assert!(first_text(&result).contains("Division by zero"));
    }

    #[tokio::test]
    async fn test_calculator_unknown_operation() {
        let server = McpServer::new();
        let args = CalculatorArgs {
            operation: "unknown".to_string(),
            a: 5.0,
            b: 3.0,
        };
        let result = server.execute_calculator(args).await;
        assert_error(&result);
        assert!(first_text(&result).contains("Unknown operation"));
    }

    #[tokio::test]
    async fn test_echo() {
        let server = McpServer::new();
        let result = server.execute_echo("Hello, MCP!".to_string()).await;
        assert_success(&result);
        assert_eq!(first_text(&result), "Hello, MCP!");
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = McpServer::new();
        let result = server.execute_server_stats().await;
        assert_success(&result);
        let text = first_text(&result);
        assert!(text.contains("ARES MCP Server"));
        assert!(text.contains("operation_count"));
        assert!(text.contains("web_search"));
    }

    #[tokio::test]
    async fn test_operation_count_increments() {
        let server = McpServer::new();
        assert_eq!(*server.operation_count.lock().await, 0);

        let _ = server.execute_echo("test".to_string()).await;
        assert_eq!(*server.operation_count.lock().await, 1);

        let args = CalculatorArgs {
            operation: "add".to_string(),
            a: 1.0,
            b: 1.0,
        };
        let _ = server.execute_calculator(args).await;
        assert_eq!(*server.operation_count.lock().await, 2);
    }

    #[tokio::test]
    async fn test_execute_tool_calculator() {
        let server = McpServer::new();
        let mut args = serde_json::Map::new();
        args.insert("operation".to_string(), json!("multiply"));
        args.insert("a".to_string(), json!(4.0));
        args.insert("b".to_string(), json!(3.0));

        let result = server.execute_tool("calculator", Some(args)).await;
        assert_success(&result);
        assert!(first_text(&result).contains("12"));
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_calculator_args() {
        let server = McpServer::new();
        let mut args = serde_json::Map::new();
        args.insert("operation".to_string(), json!("add"));
        args.insert("a".to_string(), json!(1.0));

        let result = server.execute_tool("calculator", Some(args)).await;
        assert_error(&result);
        assert!(first_text(&result).contains("Invalid calculator arguments"));
        // Rejected arguments never reach the tool, so nothing is counted.
        assert_eq!(*server.operation_count.lock().await, 0);
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_search_args() {
        let server = McpServer::new();
        let result = server.execute_tool("web_search", None).await;
        assert_error(&result);
        assert!(first_text(&result).contains("Invalid search arguments"));
    }

    #[tokio::test]
    async fn test_execute_tool_echo() {
        let server = McpServer::new();
        let mut args = serde_json::Map::new();
        args.insert("message".to_string(), json!("ping"));

        let result = server.execute_tool("echo", Some(args)).await;
        assert_success(&result);
        assert_eq!(first_text(&result), "ping");
    }

    #[tokio::test]
    async fn test_execute_tool_unknown() {
        let server = McpServer::new();
        let result = server.execute_tool("nonexistent", None).await;
        assert_error(&result);
        assert!(first_text(&result).contains("Unknown tool"));
    }
}