Skip to main content

ares/mcp/
server.rs

//! MCP (Model Context Protocol) Server Implementation
//!
//! This module provides an MCP server implementation using the `rmcp` crate,
//! bridging the existing ARES tools to MCP-compatible tools.
//!
//! The server exposes four tools (`calculator`, `web_search`, `server_stats`,
//! and `echo`) and communicates over stdio via `McpServer::start`.
//!
//! # Features
//!
//! Enable with the `mcp` feature flag:
//!
//! ```toml
//! ares = { version = "0.1", features = ["mcp"] }
//! ```
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::mcp::McpServer;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     McpServer::start().await?;
//!     Ok(())
//! }
//! ```

26use rmcp::{
27    model::*,
28    service::{RequestContext, RoleServer},
29    transport::stdio,
30    ServerHandler, ServiceExt,
31};
32use serde::{Deserialize, Serialize};
33use serde_json::json;
34
35use std::sync::Arc;
36use tokio::sync::Mutex;
37
38/// Arguments for the calculator tool
39#[derive(Debug, Deserialize, Serialize)]
40pub struct CalculatorArgs {
41    /// The arithmetic operation to perform
42    pub operation: String,
43    /// First operand
44    pub a: f64,
45    /// Second operand
46    pub b: f64,
47}
48
49/// Arguments for the web search tool
50#[derive(Debug, Deserialize, Serialize)]
51pub struct WebSearchArgs {
52    /// The search query
53    pub query: String,
54    /// Maximum number of results (default: 5)
55    #[serde(default = "default_max_results")]
56    pub max_results: usize,
57}
58
/// Serde fallback for `WebSearchArgs::max_results` when the caller omits it.
const fn default_max_results() -> usize {
    const FALLBACK_RESULT_COUNT: usize = 5;
    FALLBACK_RESULT_COUNT
}

63/// MCP Server implementation that bridges ARES tools to MCP
64#[derive(Clone)]
65pub struct McpServer {
66    /// Internal state for tracking operations
67    operation_count: Arc<Mutex<u64>>,
68}
69
70impl McpServer {
71    /// Create a new MCP server instance
72    pub fn new() -> Self {
73        Self {
74            operation_count: Arc::new(Mutex::new(0)),
75        }
76    }
77
78    /// Get list of available tools
79    fn get_tools() -> Vec<Tool> {
80        vec![
81            Tool {
82                name: "calculator".into(),
83                description: Some(
84                    "Perform basic arithmetic operations (add, subtract, multiply, divide)".into(),
85                ),
86                input_schema: serde_json::from_value(json!({
87                    "type": "object",
88                    "properties": {
89                        "operation": {
90                            "type": "string",
91                            "enum": ["add", "subtract", "multiply", "divide"],
92                            "description": "The arithmetic operation to perform"
93                        },
94                        "a": {
95                            "type": "number",
96                            "description": "First operand"
97                        },
98                        "b": {
99                            "type": "number",
100                            "description": "Second operand"
101                        }
102                    },
103                    "required": ["operation", "a", "b"]
104                }))
105                .unwrap_or_default(),
106                annotations: None,
107                icons: None,
108                meta: None,
109                output_schema: None,
110                title: Some("Calculator".into()),
111            },
112            Tool {
113                name: "web_search".into(),
114                description: Some(
115                    "Search the web for information using DuckDuckGo. Returns a list of search results with titles, snippets, and URLs.".into(),
116                ),
117                input_schema: serde_json::from_value(json!({
118                    "type": "object",
119                    "properties": {
120                        "query": {
121                            "type": "string",
122                            "description": "The search query"
123                        },
124                        "max_results": {
125                            "type": "integer",
126                            "description": "Maximum number of results (default: 5)",
127                            "default": 5
128                        }
129                    },
130                    "required": ["query"]
131                }))
132                .unwrap_or_default(),
133                annotations: None,
134                icons: None,
135                meta: None,
136                output_schema: None,
137                title: Some("Web Search".into()),
138            },
139            Tool {
140                name: "server_stats".into(),
141                description: Some(
142                    "Get statistics about the MCP server including operation count".into(),
143                ),
144                input_schema: serde_json::from_value(json!({
145                    "type": "object",
146                    "properties": {}
147                }))
148                .unwrap_or_default(),
149                annotations: None,
150                icons: None,
151                meta: None,
152                output_schema: None,
153                title: Some("Server Stats".into()),
154            },
155            Tool {
156                name: "echo".into(),
157                description: Some("Echo back the input message (useful for testing)".into()),
158                input_schema: serde_json::from_value(json!({
159                    "type": "object",
160                    "properties": {
161                        "message": {
162                            "type": "string",
163                            "description": "The message to echo back"
164                        }
165                    },
166                    "required": ["message"]
167                }))
168                .unwrap_or_default(),
169                annotations: None,
170                icons: None,
171                meta: None,
172                output_schema: None,
173                title: Some("Echo".into()),
174            },
175        ]
176    }
177
178    /// Execute the calculator tool
179    async fn execute_calculator(&self, args: CalculatorArgs) -> CallToolResult {
180        let mut count = self.operation_count.lock().await;
181        *count += 1;
182
183        let result = match args.operation.as_str() {
184            "add" => args.a + args.b,
185            "subtract" => args.a - args.b,
186            "multiply" => args.a * args.b,
187            "divide" => {
188                if args.b == 0.0 {
189                    return CallToolResult::error(vec![Content::text("Error: Division by zero")]);
190                }
191                args.a / args.b
192            }
193            op => {
194                return CallToolResult::error(vec![Content::text(format!(
195                    "Error: Unknown operation '{}'. Supported: add, subtract, multiply, divide",
196                    op
197                ))]);
198            }
199        };
200
201        let response = json!({
202            "operation": args.operation,
203            "a": args.a,
204            "b": args.b,
205            "result": result
206        });
207
208        CallToolResult::success(vec![Content::text(
209            serde_json::to_string_pretty(&response).unwrap_or_else(|_| result.to_string()),
210        )])
211    }
212
213    /// Execute the web search tool
214    async fn execute_web_search(&self, args: WebSearchArgs) -> CallToolResult {
215        let mut count = self.operation_count.lock().await;
216        *count += 1;
217
218        // Use daedra to perform the search
219        let search_args = daedra::types::SearchArgs {
220            query: args.query.clone(),
221            options: Some(daedra::types::SearchOptions {
222                num_results: args.max_results,
223                ..Default::default()
224            }),
225        };
226
227        match daedra::tools::search::perform_search(&search_args).await {
228            Ok(results) => {
229                let json_results: Vec<serde_json::Value> = results
230                    .data
231                    .into_iter()
232                    .map(|result| {
233                        json!({
234                            "title": result.title,
235                            "url": result.url,
236                            "snippet": result.description
237                        })
238                    })
239                    .collect();
240
241                let response = json!({
242                    "query": args.query,
243                    "results": json_results,
244                    "count": json_results.len()
245                });
246
247                CallToolResult::success(vec![Content::text(
248                    serde_json::to_string_pretty(&response)
249                        .unwrap_or_else(|_| "Search completed".to_string()),
250                )])
251            }
252            Err(e) => CallToolResult::error(vec![Content::text(format!("Search failed: {}", e))]),
253        }
254    }
255
256    /// Execute the server stats tool
257    async fn execute_server_stats(&self) -> CallToolResult {
258        let count = self.operation_count.lock().await;
259
260        let response = json!({
261            "server": "ARES MCP Server",
262            "version": env!("CARGO_PKG_VERSION"),
263            "operation_count": *count,
264            "available_tools": ["calculator", "web_search", "server_stats", "echo"]
265        });
266
267        CallToolResult::success(vec![Content::text(
268            serde_json::to_string_pretty(&response).unwrap_or_else(|_| "Stats unavailable".into()),
269        )])
270    }
271
272    /// Execute the echo tool
273    async fn execute_echo(&self, message: String) -> CallToolResult {
274        let mut count = self.operation_count.lock().await;
275        *count += 1;
276
277        CallToolResult::success(vec![Content::text(message)])
278    }
279
280    /// Execute a tool by name
281    async fn execute_tool(
282        &self,
283        name: &str,
284        arguments: Option<serde_json::Map<String, serde_json::Value>>,
285    ) -> CallToolResult {
286        let args = arguments.unwrap_or_default();
287        let args_value = serde_json::Value::Object(args);
288
289        match name {
290            "calculator" => match serde_json::from_value::<CalculatorArgs>(args_value) {
291                Ok(calc_args) => self.execute_calculator(calc_args).await,
292                Err(e) => CallToolResult::error(vec![Content::text(format!(
293                    "Invalid calculator arguments: {}",
294                    e
295                ))]),
296            },
297            "web_search" => match serde_json::from_value::<WebSearchArgs>(args_value) {
298                Ok(search_args) => self.execute_web_search(search_args).await,
299                Err(e) => CallToolResult::error(vec![Content::text(format!(
300                    "Invalid search arguments: {}",
301                    e
302                ))]),
303            },
304            "server_stats" => self.execute_server_stats().await,
305            "echo" => {
306                let message = args_value
307                    .get("message")
308                    .and_then(|v| v.as_str())
309                    .unwrap_or("")
310                    .to_string();
311                self.execute_echo(message).await
312            }
313            _ => CallToolResult::error(vec![Content::text(format!("Unknown tool: {}", name))]),
314        }
315    }
316}
317
318impl Default for McpServer {
319    fn default() -> Self {
320        Self::new()
321    }
322}
323
324/// Implement ServerHandler for MCP protocol
325impl ServerHandler for McpServer {
326    fn get_info(&self) -> ServerInfo {
327        ServerInfo {
328            protocol_version: ProtocolVersion::LATEST,
329            capabilities: ServerCapabilities::builder().enable_tools().build(),
330            server_info: Implementation::from_build_env(),
331            instructions: Some(
332                "A.R.E.S MCP Server - Provides calculator, web search, and utility tools".into(),
333            ),
334        }
335    }
336
337    async fn list_tools(
338        &self,
339        _request: Option<PaginatedRequestParam>,
340        _context: RequestContext<RoleServer>,
341    ) -> Result<ListToolsResult, rmcp::ErrorData> {
342        Ok(ListToolsResult {
343            tools: Self::get_tools(),
344            next_cursor: None,
345            meta: None,
346        })
347    }
348
349    async fn call_tool(
350        &self,
351        request: CallToolRequestParam,
352        _context: RequestContext<RoleServer>,
353    ) -> Result<CallToolResult, rmcp::ErrorData> {
354        Ok(self.execute_tool(&request.name, request.arguments).await)
355    }
356}
357
358impl McpServer {
359    /// Start the MCP server with stdio transport
360    ///
361    /// This function blocks until the server is shut down.
362    ///
363    /// # Errors
364    ///
365    /// Returns an error if the server fails to start or encounters a fatal error.
366    pub async fn start() -> crate::types::Result<()> {
367        tracing::info!("Starting A.R.E.S MCP Server v{}", env!("CARGO_PKG_VERSION"));
368
369        let server = McpServer::new();
370
371        // Serve using stdio transport (standard for MCP)
372        let service = server
373            .serve(stdio())
374            .await
375            .map_err(|e| crate::types::AppError::External(format!("MCP server error: {}", e)))?;
376
377        tracing::info!("MCP server started successfully");
378
379        // Wait for the service to complete
380        service
381            .waiting()
382            .await
383            .map_err(|e| crate::types::AppError::External(format!("MCP server error: {}", e)))?;
384
385        tracing::info!("MCP server shut down");
386        Ok(())
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Return the text of the first content item.
    ///
    /// Fails the test when there is no content or it is not text. The previous
    /// `if let` pattern silently passed in that case.
    fn first_text(result: &CallToolResult) -> &str {
        match result.content.first().map(|content| &content.raw) {
            Some(RawContent::Text(text)) => &text.text,
            _ => panic!("expected the first content item to be text"),
        }
    }

    /// Parse the first text content item of a tool result as JSON.
    fn first_json(result: &CallToolResult) -> serde_json::Value {
        serde_json::from_str(first_text(result)).expect("tool output should be valid JSON")
    }

    #[test]
    fn test_calculator_args_parsing() {
        let json = r#"{"operation": "add", "a": 5.0, "b": 3.0}"#;
        let args: CalculatorArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args.operation, "add");
        assert_eq!(args.a, 5.0);
        assert_eq!(args.b, 3.0);
    }

    #[test]
    fn test_web_search_args_default() {
        let json = r#"{"query": "test query"}"#;
        let args: WebSearchArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args.query, "test query");
        assert_eq!(args.max_results, 5); // default value
    }

    #[test]
    fn test_web_search_args_with_max_results() {
        let json = r#"{"query": "test query", "max_results": 10}"#;
        let args: WebSearchArgs = serde_json::from_str(json).unwrap();
        assert_eq!(args.query, "test query");
        assert_eq!(args.max_results, 10);
    }

    #[tokio::test]
    async fn test_mcp_server_creation() {
        let server = McpServer::new();
        assert_eq!(*server.operation_count.lock().await, 0);
    }

    #[tokio::test]
    async fn test_mcp_server_default() {
        let server = McpServer::default();
        assert_eq!(*server.operation_count.lock().await, 0);
    }

    #[test]
    fn test_get_tools() {
        let tools = McpServer::get_tools();
        assert_eq!(tools.len(), 4);

        let tool_names: Vec<String> = tools.iter().map(|t| t.name.to_string()).collect();
        assert!(tool_names.contains(&"calculator".to_string()));
        assert!(tool_names.contains(&"web_search".to_string()));
        assert!(tool_names.contains(&"server_stats".to_string()));
        assert!(tool_names.contains(&"echo".to_string()));
    }

    #[tokio::test]
    async fn test_calculator_add() {
        let server = McpServer::new();
        let args = CalculatorArgs {
            operation: "add".to_string(),
            a: 5.0,
            b: 3.0,
        };
        let result = server.execute_calculator(args).await;
        assert_eq!(result.is_error, Some(false));
        let response = first_json(&result);
        assert_eq!(response["operation"], "add");
        assert_eq!(response["result"].as_f64(), Some(8.0));
    }

    #[tokio::test]
    async fn test_calculator_divide_by_zero() {
        let server = McpServer::new();
        let args = CalculatorArgs {
            operation: "divide".to_string(),
            a: 5.0,
            b: 0.0,
        };
        let result = server.execute_calculator(args).await;
        assert_eq!(result.is_error, Some(true));
        assert!(first_text(&result).contains("Division by zero"));
    }

    #[tokio::test]
    async fn test_calculator_unknown_operation() {
        let server = McpServer::new();
        let args = CalculatorArgs {
            operation: "unknown".to_string(),
            a: 5.0,
            b: 3.0,
        };
        let result = server.execute_calculator(args).await;
        assert_eq!(result.is_error, Some(true));
        assert!(first_text(&result).contains("Unknown operation"));
    }

    #[tokio::test]
    async fn test_echo() {
        let server = McpServer::new();
        let result = server.execute_echo("Hello, MCP!".to_string()).await;
        assert_eq!(result.is_error, Some(false));
        assert_eq!(first_text(&result), "Hello, MCP!");
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = McpServer::new();
        let result = server.execute_server_stats().await;
        assert_eq!(result.is_error, Some(false));
        let stats = first_json(&result);
        assert_eq!(stats["server"], "ARES MCP Server");
        assert_eq!(stats["operation_count"].as_u64(), Some(0));
        assert_eq!(
            stats["available_tools"],
            json!(["calculator", "web_search", "server_stats", "echo"])
        );
    }

    #[tokio::test]
    async fn test_operation_count_increments() {
        let server = McpServer::new();

        // Initial count should be 0
        {
            let count = server.operation_count.lock().await;
            assert_eq!(*count, 0);
        }

        // Perform an operation
        let _ = server.execute_echo("test".to_string()).await;

        // Count should be 1
        {
            let count = server.operation_count.lock().await;
            assert_eq!(*count, 1);
        }

        // Perform another operation
        let args = CalculatorArgs {
            operation: "add".to_string(),
            a: 1.0,
            b: 1.0,
        };
        let _ = server.execute_calculator(args).await;

        // Count should be 2
        {
            let count = server.operation_count.lock().await;
            assert_eq!(*count, 2);
        }
    }

    #[tokio::test]
    async fn test_execute_tool_calculator() {
        let server = McpServer::new();
        let mut args = serde_json::Map::new();
        args.insert("operation".to_string(), json!("multiply"));
        args.insert("a".to_string(), json!(4.0));
        args.insert("b".to_string(), json!(3.0));

        let result = server.execute_tool("calculator", Some(args)).await;
        assert_eq!(result.is_error, Some(false));
        assert_eq!(first_json(&result)["result"].as_f64(), Some(12.0));
    }

    #[tokio::test]
    async fn test_execute_tool_calculator_invalid_args() {
        let server = McpServer::new();
        let mut args = serde_json::Map::new();
        args.insert("operation".to_string(), json!("add"));
        args.insert("a".to_string(), json!(1.0));

        let result = server.execute_tool("calculator", Some(args)).await;
        assert_eq!(result.is_error, Some(true));
        assert!(first_text(&result).contains("Invalid calculator arguments"));
    }

    #[tokio::test]
    async fn test_execute_tool_web_search_missing_query() {
        // Fails during argument parsing, so no network request is made.
        let server = McpServer::new();
        let result = server.execute_tool("web_search", None).await;
        assert_eq!(result.is_error, Some(true));
        assert!(first_text(&result).contains("Invalid search arguments"));
    }

    #[tokio::test]
    async fn test_execute_tool_unknown() {
        let server = McpServer::new();
        let result = server.execute_tool("nonexistent", None).await;
        assert_eq!(result.is_error, Some(true));
        assert!(first_text(&result).contains("Unknown tool"));
    }
}
564        if let RawContent::Text(text) = &content.raw {
565            assert!(text.text.contains("Unknown tool"));
566        }
567    }
568}