reasonkit/mcp/
server.rs

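//! MCP server implementation.
//!
//! Defines the [`McpServerTrait`] interface, its concrete implementation
//! [`McpServer`], a server-side stdio transport, and [`run_server`], which
//! serves JSON-RPC over the current process's stdin/stdout.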
4
5use super::lifecycle::InitializeResult;
6use super::tools::{Tool, ToolHandler, ToolResult};
7use super::transport::Transport;
8use super::types::*;
9use crate::error::{Error, Result};
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::RwLock;
16use uuid::Uuid;
17
18/// Server status
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum ServerStatus {
22    /// Server is starting up
23    Starting,
24    /// Server is running and healthy
25    Running,
26    /// Server is degraded (responding slowly)
27    Degraded,
28    /// Server is not responding
29    Unhealthy,
30    /// Server is shutting down
31    Stopping,
32    /// Server has stopped
33    Stopped,
34    /// Server encountered a fatal error
35    Failed,
36}
37
38/// Server metrics
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ServerMetrics {
41    /// Total requests handled
42    pub requests_total: u64,
43    /// Total errors encountered
44    pub errors_total: u64,
45    /// Average response time (ms)
46    pub avg_response_time_ms: f64,
47    /// Last successful request timestamp
48    pub last_success_at: Option<DateTime<Utc>>,
49    /// Last error timestamp
50    pub last_error_at: Option<DateTime<Utc>>,
51    /// Uptime in seconds
52    pub uptime_secs: u64,
53}
54
55impl Default for ServerMetrics {
56    fn default() -> Self {
57        Self {
58            requests_total: 0,
59            errors_total: 0,
60            avg_response_time_ms: 0.0,
61            last_success_at: None,
62            last_error_at: None,
63            uptime_secs: 0,
64        }
65    }
66}
67
68/// MCP server trait
69#[async_trait]
70pub trait McpServerTrait: Send + Sync {
71    /// Get server information
72    async fn server_info(&self) -> ServerInfo;
73
74    /// Get server capabilities
75    async fn capabilities(&self) -> ServerCapabilities;
76
77    /// Initialize the server
78    async fn initialize(&mut self, params: serde_json::Value) -> Result<serde_json::Value>;
79
80    /// Shutdown the server
81    async fn shutdown(&mut self) -> Result<()>;
82
83    /// Send a request to the server
84    async fn send_request(&self, request: McpRequest) -> Result<McpResponse>;
85
86    /// Send a notification to the server (no response expected)
87    async fn send_notification(&self, notification: McpNotification) -> Result<()>;
88
89    /// Get current server status
90    async fn status(&self) -> ServerStatus;
91
92    /// Get server metrics
93    async fn metrics(&self) -> ServerMetrics;
94
95    /// Perform a health check
96    async fn health_check(&self) -> Result<bool>;
97
98    /// List available tools
99    async fn list_tools(&self) -> Vec<Tool>;
100
101    /// Call a tool
102    async fn call_tool(
103        &self,
104        name: &str,
105        arguments: HashMap<String, serde_json::Value>,
106    ) -> Result<ToolResult>;
107
108    /// Register a tool
109    async fn register_tool(&self, tool: Tool, handler: Arc<dyn ToolHandler>);
110}
111
112type ToolRegistry = HashMap<String, (Tool, Arc<dyn ToolHandler>)>;
113
114/// Concrete MCP server implementation
115pub struct McpServer {
116    /// Server ID
117    pub id: Uuid,
118    /// Server name
119    pub name: String,
120    /// Server information
121    pub info: ServerInfo,
122    /// Server capabilities
123    pub capabilities: ServerCapabilities,
124    /// Transport layer
125    transport: Arc<dyn Transport>,
126    /// Current status
127    status: Arc<RwLock<ServerStatus>>,
128    /// Server metrics
129    metrics: Arc<RwLock<ServerMetrics>>,
130    /// Server started at
131    started_at: DateTime<Utc>,
132    /// Registered tools
133    tools: Arc<RwLock<ToolRegistry>>,
134}
135
136impl McpServer {
137    /// Create a new MCP server
138    pub fn new(
139        name: impl Into<String>,
140        info: ServerInfo,
141        capabilities: ServerCapabilities,
142        transport: Arc<dyn Transport>,
143    ) -> Self {
144        Self {
145            id: Uuid::new_v4(),
146            name: name.into(),
147            info,
148            capabilities,
149            transport,
150            status: Arc::new(RwLock::new(ServerStatus::Starting)),
151            metrics: Arc::new(RwLock::new(ServerMetrics::default())),
152            started_at: Utc::now(),
153            tools: Arc::new(RwLock::new(HashMap::new())),
154        }
155    }
156
157    /// Update server status
158    pub async fn set_status(&self, status: ServerStatus) {
159        let mut s = self.status.write().await;
160        *s = status;
161    }
162
163    /// Record a successful request
164    pub async fn record_success(&self, response_time_ms: f64) {
165        let mut m = self.metrics.write().await;
166        m.requests_total += 1;
167        m.last_success_at = Some(Utc::now());
168
169        // Update average response time (exponential moving average)
170        if m.requests_total == 1 {
171            m.avg_response_time_ms = response_time_ms;
172        } else {
173            m.avg_response_time_ms = (m.avg_response_time_ms * 0.9) + (response_time_ms * 0.1);
174        }
175    }
176
177    /// Record an error
178    pub async fn record_error(&self) {
179        let mut m = self.metrics.write().await;
180        m.errors_total += 1;
181        m.last_error_at = Some(Utc::now());
182    }
183
184    /// Get uptime in seconds
185    pub fn uptime_secs(&self) -> u64 {
186        (Utc::now() - self.started_at).num_seconds() as u64
187    }
188}
189
190#[async_trait]
191impl McpServerTrait for McpServer {
192    async fn server_info(&self) -> ServerInfo {
193        self.info.clone()
194    }
195
196    async fn capabilities(&self) -> ServerCapabilities {
197        self.capabilities.clone()
198    }
199
200    async fn initialize(&mut self, params: serde_json::Value) -> Result<serde_json::Value> {
201        let request = McpRequest::new(
202            RequestId::String(Uuid::new_v4().to_string()),
203            "initialize",
204            Some(params),
205        );
206
207        let start = std::time::Instant::now();
208        let response = self
209            .transport
210            .send_request(request)
211            .await
212            .map_err(|e| Error::network(format!("Initialize failed: {}", e)))?;
213
214        let elapsed_ms = start.elapsed().as_millis() as f64;
215
216        if let Some(error) = response.error {
217            self.record_error().await;
218            return Err(Error::network(format!(
219                "Initialize error: {}",
220                error.message
221            )));
222        }
223
224        self.record_success(elapsed_ms).await;
225        self.set_status(ServerStatus::Running).await;
226
227        response
228            .result
229            .ok_or_else(|| Error::network("Initialize response missing result"))
230    }
231
232    async fn shutdown(&mut self) -> Result<()> {
233        self.set_status(ServerStatus::Stopping).await;
234
235        let request = McpRequest::new(
236            RequestId::String(Uuid::new_v4().to_string()),
237            "shutdown",
238            None,
239        );
240
241        let response = self
242            .transport
243            .send_request(request)
244            .await
245            .map_err(|e| Error::network(format!("Shutdown failed: {}", e)))?;
246
247        if response.error.is_some() {
248            self.set_status(ServerStatus::Failed).await;
249        } else {
250            self.set_status(ServerStatus::Stopped).await;
251        }
252
253        Ok(())
254    }
255
256    async fn send_request(&self, request: McpRequest) -> Result<McpResponse> {
257        let start = std::time::Instant::now();
258
259        let response = self
260            .transport
261            .send_request(request)
262            .await
263            .map_err(|e| Error::network(format!("Request failed: {}", e)))?;
264
265        let elapsed_ms = start.elapsed().as_millis() as f64;
266
267        if response.error.is_some() {
268            self.record_error().await;
269        } else {
270            self.record_success(elapsed_ms).await;
271        }
272
273        Ok(response)
274    }
275
276    async fn send_notification(&self, notification: McpNotification) -> Result<()> {
277        self.transport
278            .send_notification(notification)
279            .await
280            .map_err(|e| Error::network(format!("Notification failed: {}", e)))
281    }
282
283    async fn status(&self) -> ServerStatus {
284        *self.status.read().await
285    }
286
287    async fn metrics(&self) -> ServerMetrics {
288        let mut m = self.metrics.read().await.clone();
289        m.uptime_secs = self.uptime_secs();
290        m
291    }
292
293    async fn health_check(&self) -> Result<bool> {
294        let request = McpRequest::new(RequestId::String(Uuid::new_v4().to_string()), "ping", None);
295
296        match tokio::time::timeout(
297            std::time::Duration::from_secs(5),
298            self.transport.send_request(request),
299        )
300        .await
301        {
302            Ok(Ok(response)) => {
303                if response.error.is_none() {
304                    self.set_status(ServerStatus::Running).await;
305                    Ok(true)
306                } else {
307                    self.set_status(ServerStatus::Degraded).await;
308                    Ok(false)
309                }
310            }
311            Ok(Err(_)) | Err(_) => {
312                self.set_status(ServerStatus::Unhealthy).await;
313                Ok(false)
314            }
315        }
316    }
317
318    async fn list_tools(&self) -> Vec<Tool> {
319        let tools = self.tools.read().await;
320        tools.values().map(|(t, _)| t.clone()).collect()
321    }
322
323    async fn call_tool(
324        &self,
325        name: &str,
326        arguments: HashMap<String, serde_json::Value>,
327    ) -> Result<ToolResult> {
328        let handler = {
329            let tools = self.tools.read().await;
330            tools.get(name).map(|(_, handler)| Arc::clone(handler))
331        };
332
333        match handler {
334            Some(handler) => handler.call(arguments).await,
335            None => Err(Error::Mcp(format!("Tool not found: {}", name))),
336        }
337    }
338
339    async fn register_tool(&self, tool: Tool, handler: Arc<dyn ToolHandler>) {
340        let mut tools = self.tools.write().await;
341        tools.insert(tool.name.clone(), (tool, handler));
342    }
343}
344
345/// Server-side Stdio Transport (uses current process stdin/stdout)
346pub struct ServerStdioTransport {
347    stdout: tokio::sync::Mutex<tokio::io::Stdout>,
348}
349
350impl ServerStdioTransport {
351    pub fn new() -> Self {
352        Self {
353            stdout: tokio::sync::Mutex::new(tokio::io::stdout()),
354        }
355    }
356}
357
358impl Default for ServerStdioTransport {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
364#[async_trait]
365impl Transport for ServerStdioTransport {
366    async fn send_request(&self, _request: McpRequest) -> std::io::Result<McpResponse> {
367        // Server sending request to client (e.g. sampling) - not implemented yet
368        Err(std::io::Error::new(
369            std::io::ErrorKind::Unsupported,
370            "Server-to-client requests not supported yet",
371        ))
372    }
373
374    async fn send_notification(&self, notification: McpNotification) -> std::io::Result<()> {
375        use tokio::io::AsyncWriteExt;
376        let json = serde_json::to_string(&notification)
377            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
378
379        let mut stdout = self.stdout.lock().await;
380        stdout.write_all(json.as_bytes()).await?;
381        stdout.write_all(b"\n").await?;
382        stdout.flush().await?;
383        Ok(())
384    }
385
386    async fn close(&self) -> std::io::Result<()> {
387        Ok(())
388    }
389}
390
391/// Run the MCP server
392pub async fn run_server() -> Result<()> {
393    use tokio::io::AsyncBufReadExt;
394    use tokio::io::AsyncWriteExt;
395
396    let transport = Arc::new(ServerStdioTransport::new());
397
398    let info = ServerInfo {
399        name: "reasonkit-core".to_string(),
400        version: env!("CARGO_PKG_VERSION").to_string(),
401        description: Some("ReasonKit Core MCP Server".to_string()),
402        vendor: Some("ReasonKit Team".to_string()),
403    };
404
405    let capabilities = ServerCapabilities {
406        logging: Some(LoggingCapability {}),
407        prompts: Some(PromptsCapability { list_changed: true }),
408        resources: Some(ResourcesCapability {
409            subscribe: true,
410            list_changed: true,
411        }),
412        tools: Some(ToolsCapability { list_changed: true }),
413    };
414
415    let server = McpServer::new("reasonkit-core", info, capabilities, transport.clone());
416
417    // Register ThinkTools
418    crate::mcp::thinktool_tools::register_thinktools(&server).await;
419
420    // Main loop
421    let stdin = tokio::io::stdin();
422    let mut reader = tokio::io::BufReader::new(stdin);
423    let mut line = String::new();
424
425    eprintln!("ReasonKit Core MCP Server running on stdio...");
426
427    loop {
428        line.clear();
429        let bytes_read = reader
430            .read_line(&mut line)
431            .await
432            .map_err(|e| Error::network(format!("Failed to read line: {}", e)))?;
433
434        if bytes_read == 0 {
435            break; // EOF
436        }
437
438        let msg: serde_json::Value = match serde_json::from_str(&line) {
439            Ok(m) => m,
440            Err(e) => {
441                eprintln!("Failed to parse JSON: {}", e);
442
443                let response = serde_json::json!({
444                    "jsonrpc": "2.0",
445                    "id": serde_json::Value::Null,
446                    "error": {
447                        "code": ErrorCode::PARSE_ERROR.0,
448                        "message": e.to_string()
449                    }
450                });
451
452                let response_str = serde_json::to_string(&response).map_err(Error::from)?;
453                let mut stdout = transport.stdout.lock().await;
454                stdout
455                    .write_all(response_str.as_bytes())
456                    .await
457                    .map_err(Error::from)?;
458                stdout.write_all(b"\n").await.map_err(Error::from)?;
459                stdout.flush().await.map_err(Error::from)?;
460                continue;
461            }
462        };
463
464        // Handle JSON-RPC message
465        if let Some(method) = msg.get("method").and_then(|m| m.as_str()) {
466            // It's a request or notification
467            if let Some(id) = msg.get("id") {
468                // Request
469                let result: Result<serde_json::Value> = match method {
470                    "initialize" => {
471                        server.set_status(ServerStatus::Running).await;
472                        let init =
473                            InitializeResult::new(server.info.clone(), server.capabilities.clone());
474                        Ok(serde_json::to_value(init).map_err(Error::from)?)
475                    }
476                    "shutdown" => {
477                        server.set_status(ServerStatus::Stopping).await;
478                        Ok(serde_json::json!(null))
479                    }
480                    "ping" => Ok(serde_json::json!({})),
481                    "tools/list" => {
482                        let tools = server.list_tools().await;
483                        Ok(serde_json::json!({ "tools": tools }))
484                    }
485                    "tools/call" => {
486                        let params = match msg.get("params").and_then(|p| p.as_object()) {
487                            Some(params) => params,
488                            None => Err(Error::Mcp("Invalid params".to_string()))?,
489                        };
490
491                        let name = match params.get("name").and_then(|v| v.as_str()) {
492                            Some(name) => name,
493                            None => Err(Error::Mcp("Missing tool name".to_string()))?,
494                        };
495
496                        let args: std::result::Result<HashMap<String, serde_json::Value>, Error> =
497                            match params.get("arguments") {
498                                Some(v) if v.is_object() => {
499                                    serde_json::from_value(v.clone()).map_err(Error::from)
500                                }
501                                Some(_) => Err(Error::Mcp("Invalid tool arguments".to_string())),
502                                None => Ok(HashMap::new()),
503                            };
504
505                        match args {
506                            Ok(args) => match server.call_tool(name, args).await {
507                                Ok(res) => serde_json::to_value(res).map_err(Error::from),
508                                Err(e) => Err(e),
509                            },
510                            Err(e) => Err(e),
511                        }
512                    }
513                    _ => Err(Error::Mcp(format!("Method not found: {}", method))),
514                };
515
516                let (code, message) = match &result {
517                    Err(Error::Mcp(message)) if message.starts_with("Method not found:") => {
518                        (ErrorCode::METHOD_NOT_FOUND.0, message.clone())
519                    }
520                    Err(Error::Mcp(message)) if message.starts_with("Tool not found:") => {
521                        (ErrorCode::TOOL_NOT_FOUND.0, message.clone())
522                    }
523                    Err(Error::Mcp(message)) => (ErrorCode::INVALID_PARAMS.0, message.clone()),
524                    Err(e) => (ErrorCode::INTERNAL_ERROR.0, e.to_string()),
525                    Ok(_) => (0, String::new()),
526                };
527
528                let response = match result {
529                    Ok(res) => serde_json::json!({
530                        "jsonrpc": "2.0",
531                        "id": id,
532                        "result": res
533                    }),
534                    Err(_) => serde_json::json!({
535                        "jsonrpc": "2.0",
536                        "id": id,
537                        "error": {
538                            "code": code,
539                            "message": message
540                        }
541                    }),
542                };
543
544                let json_line = serde_json::to_string(&response).map_err(Error::from)?;
545                let mut stdout = transport.stdout.lock().await;
546                stdout
547                    .write_all(json_line.as_bytes())
548                    .await
549                    .map_err(Error::from)?;
550                stdout.write_all(b"\n").await.map_err(Error::from)?;
551                stdout.flush().await.map_err(Error::from)?;
552
553                if method == "shutdown" {
554                    server.set_status(ServerStatus::Stopped).await;
555                    break;
556                }
557            } else {
558                // Notification
559                if method == "notifications/initialized" {
560                    // Handle initialized notification
561                }
562            }
563        }
564    }
565
566    Ok(())
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::tools::ToolResultContent;

    /// Build a minimal server for in-process tests. None of the tests below
    /// send anything over the transport.
    fn test_server() -> McpServer {
        let transport = Arc::new(ServerStdioTransport::new());
        let info = ServerInfo {
            name: "test".to_string(),
            version: "1.0".to_string(),
            description: None,
            vendor: None,
        };
        McpServer::new("test", info, ServerCapabilities::default(), transport)
    }

    /// Echoes the `message` argument back, falling back to "default".
    struct EchoTool;

    #[async_trait]
    impl ToolHandler for EchoTool {
        async fn call(&self, args: HashMap<String, serde_json::Value>) -> Result<ToolResult> {
            let msg = args
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("default");
            Ok(ToolResult::text(format!("Echo: {}", msg)))
        }
    }

    #[test]
    fn test_server_status() {
        let cases = [
            (ServerStatus::Starting, "\"starting\""),
            (ServerStatus::Running, "\"running\""),
            (ServerStatus::Degraded, "\"degraded\""),
            (ServerStatus::Unhealthy, "\"unhealthy\""),
            (ServerStatus::Stopping, "\"stopping\""),
            (ServerStatus::Stopped, "\"stopped\""),
            (ServerStatus::Failed, "\"failed\""),
        ];
        for (status, expected) in cases.iter() {
            let json = serde_json::to_string(status).unwrap();
            assert_eq!(json, *expected);
            let back: ServerStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(back, *status);
        }
    }

    #[test]
    fn test_metrics_default() {
        let metrics = ServerMetrics::default();
        assert_eq!(metrics.requests_total, 0);
        assert_eq!(metrics.errors_total, 0);
        assert!(metrics.last_success_at.is_none());
        assert!(metrics.last_error_at.is_none());
    }

    #[tokio::test]
    async fn test_metrics_recording() {
        let server = test_server();
        server.record_success(10.0).await;
        server.record_success(20.0).await;
        server.record_error().await;

        let m = server.metrics().await;
        assert_eq!(m.requests_total, 2);
        assert_eq!(m.errors_total, 1);
        // The first sample seeds the average; the second is blended with
        // weight 0.1: 10 * 0.9 + 20 * 0.1 = 11.
        assert!((m.avg_response_time_ms - 11.0).abs() < 1e-9);
        assert!(m.last_success_at.is_some());
        assert!(m.last_error_at.is_some());
    }

    #[tokio::test]
    async fn test_tool_execution() {
        let server = test_server();

        server
            .register_tool(Tool::simple("echo", "Echoes back"), Arc::new(EchoTool))
            .await;

        let tools = server.list_tools().await;
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");

        let mut args = HashMap::new();
        args.insert("message".to_string(), serde_json::json!("hello"));

        let result = server.call_tool("echo", args).await.unwrap();
        match &result.content[0] {
            ToolResultContent::Text { text } => assert_eq!(text, "Echo: hello"),
            _ => panic!("Wrong content type"),
        }
    }

    #[tokio::test]
    async fn test_unknown_tool_is_error() {
        let server = test_server();
        match server.call_tool("missing", HashMap::new()).await {
            Err(Error::Mcp(msg)) => assert_eq!(msg, "Tool not found: missing"),
            Err(other) => panic!("unexpected error: {}", other),
            Ok(_) => panic!("expected an error for an unregistered tool"),
        }
    }
}