// allframe_mcp/stdio.rs

//! STDIO transport for the MCP server, with debugging support.
//!
//! This module provides a production-ready stdio transport with:
//! - Structured logging via `tracing` (when the `tracing` feature is enabled)
//! - Request/response tracing for debugging
//! - Graceful shutdown handling
//! - A built-in diagnostic tool (`allframe/debug`)
//!
//! Stdout carries only JSON-RPC messages; this module writes its own diagnostics to
//! stderr or through `tracing`. Without the `tracing` feature, per-request debug lines
//! are printed only when the `ALLFRAME_MCP_DEBUG` environment variable is set.
//!
//! # Usage
//!
//! ```rust,no_run
//! use allframe_core::router::Router;
//! use allframe_mcp::{McpServer, StdioTransport, StdioConfig};
//!
//! #[tokio::main]
//! async fn main() {
//!     let router = Router::new();
//!     let mcp = McpServer::new(router);
//!
//!     let config = StdioConfig::default()
//!         .with_debug_tool(true);
//!
//!     StdioTransport::new(mcp, config)
//!         .serve()
//!         .await;
//! }
//! ```

29use std::io::{stdin, stdout, BufRead, Write};
30use std::sync::atomic::{AtomicU64, Ordering};
31use std::time::Instant;
32
33use serde_json::{json, Value};
34
35use crate::McpServer;
36
/// Configuration for the STDIO transport.
///
/// Build one with `StdioConfig::default()` and the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StdioConfig {
    /// Server name reported in the `initialize` response and in diagnostics.
    pub server_name: String,
    /// Server version reported in the `initialize` response and in diagnostics.
    pub server_version: String,
    /// MCP protocol version advertised in the `initialize` response.
    pub protocol_version: String,
    /// Whether to advertise and serve the built-in `allframe/debug` tool.
    pub include_debug_tool: bool,
    /// Intended log file path (logs would go to this file instead of stderr).
    ///
    /// NOTE(review): `StdioTransport` never reads this field. File logging is currently
    /// configured by `init_tracing`, which reads `ALLFRAME_MCP_LOG_FILE` directly.
    pub log_file: Option<String>,
}

52impl Default for StdioConfig {
53    fn default() -> Self {
54        Self {
55            server_name: "allframe-mcp".to_string(),
56            server_version: env!("CARGO_PKG_VERSION").to_string(),
57            protocol_version: "2024-11-05".to_string(),
58            include_debug_tool: false,
59            log_file: std::env::var("ALLFRAME_MCP_LOG_FILE").ok(),
60        }
61    }
62}
63
64impl StdioConfig {
65    /// Enable the built-in debug tool
66    pub fn with_debug_tool(mut self, enabled: bool) -> Self {
67        self.include_debug_tool = enabled;
68        self
69    }
70
71    /// Set the server name
72    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
73        self.server_name = name.into();
74        self
75    }
76
77    /// Set a log file path
78    pub fn with_log_file(mut self, path: impl Into<String>) -> Self {
79        self.log_file = Some(path.into());
80        self
81    }
82}
83
84/// STDIO transport for MCP server with debugging support
85pub struct StdioTransport {
86    mcp: McpServer,
87    config: StdioConfig,
88    start_time: Instant,
89    request_count: AtomicU64,
90}
91
92impl StdioTransport {
93    /// Create a new STDIO transport
94    pub fn new(mcp: McpServer, config: StdioConfig) -> Self {
95        Self {
96            mcp,
97            config,
98            start_time: Instant::now(),
99            request_count: AtomicU64::new(0),
100        }
101    }
102
103    /// Serve MCP protocol over stdio
104    pub async fn serve(self) {
105        self.log_startup();
106
107        let stdin = stdin();
108        let mut stdout = stdout();
109
110        // Set up shutdown signal handling
111        let shutdown = async {
112            #[cfg(unix)]
113            {
114                use tokio::signal::unix::{signal, SignalKind};
115                let mut sigterm = signal(SignalKind::terminate()).ok();
116                let mut sigint = signal(SignalKind::interrupt()).ok();
117
118                tokio::select! {
119                    _ = async { if let Some(ref mut s) = sigterm { s.recv().await } else { std::future::pending().await } } => {
120                        self.log_info("Received SIGTERM");
121                    }
122                    _ = async { if let Some(ref mut s) = sigint { s.recv().await } else { std::future::pending().await } } => {
123                        self.log_info("Received SIGINT");
124                    }
125                }
126            }
127            #[cfg(not(unix))]
128            {
129                tokio::signal::ctrl_c().await.ok();
130                self.log_info("Received shutdown signal");
131            }
132        };
133
134        // Run the main loop with shutdown handling
135        tokio::select! {
136            _ = self.run_loop(&stdin, &mut stdout) => {}
137            _ = shutdown => {
138                self.log_info("Shutting down gracefully");
139            }
140        }
141
142        self.log_shutdown();
143    }
144
145    async fn run_loop(&self, stdin: &std::io::Stdin, stdout: &mut std::io::Stdout) {
146        for line in stdin.lock().lines() {
147            let line = match line {
148                Ok(l) => l,
149                Err(e) => {
150                    self.log_error(&format!("Error reading line: {}", e));
151                    continue;
152                }
153            };
154
155            // Skip empty lines
156            if line.trim().is_empty() {
157                continue;
158            }
159
160            self.request_count.fetch_add(1, Ordering::SeqCst);
161            let request_id = self.request_count.load(Ordering::SeqCst);
162
163            self.log_request(request_id, &line);
164
165            // Parse request
166            let request: Value = match serde_json::from_str(&line) {
167                Ok(r) => r,
168                Err(e) => {
169                    self.log_error(&format!("Parse error: {}", e));
170                    let error = json!({
171                        "jsonrpc": "2.0",
172                        "error": {
173                            "code": -32700,
174                            "message": "Parse error"
175                        },
176                        "id": null
177                    });
178                    self.write_response(stdout, &error, request_id);
179                    continue;
180                }
181            };
182
183            // Handle request
184            let response = self.handle_request(request).await;
185
186            // Check if this was a notification (no response needed)
187            if let Some(resp) = response {
188                self.write_response(stdout, &resp, request_id);
189            }
190        }
191    }
192
193    fn write_response(&self, stdout: &mut std::io::Stdout, response: &Value, request_id: u64) {
194        match serde_json::to_string(&response) {
195            Ok(json_str) => {
196                self.log_response(request_id, &json_str);
197                if let Err(e) = writeln!(stdout, "{}", json_str) {
198                    self.log_error(&format!("Error writing response: {}", e));
199                }
200                if let Err(e) = stdout.flush() {
201                    self.log_error(&format!("Error flushing stdout: {}", e));
202                }
203            }
204            Err(e) => {
205                self.log_error(&format!("Error serializing response: {}", e));
206            }
207        }
208    }
209
210    async fn handle_request(&self, request: Value) -> Option<Value> {
211        let method = request["method"].as_str().unwrap_or("");
212        let id = request.get("id").cloned();
213
214        // Handle notifications (no id = notification, no response needed)
215        match method {
216            // Notifications that don't require responses
217            "initialized" | "notifications/initialized" => {
218                self.log_info("Client initialized connection");
219                return None;
220            }
221            "notifications/cancelled" => {
222                self.log_info("Request cancelled by client");
223                return None;
224            }
225            _ => {}
226        }
227
228        let result = match method {
229            // Initialize
230            "initialize" => {
231                self.log_info("Initializing MCP connection");
232                json!({
233                    "protocolVersion": self.config.protocol_version,
234                    "capabilities": {
235                        "tools": {}
236                    },
237                    "serverInfo": {
238                        "name": self.config.server_name,
239                        "version": self.config.server_version
240                    }
241                })
242            }
243
244            // List available tools
245            "tools/list" => {
246                let mut tools: Vec<Value> = self.mcp.list_tools().await.iter().map(|t| {
247                    json!({
248                        "name": t.name,
249                        "description": t.description,
250                        "inputSchema": serde_json::from_str::<Value>(&t.input_schema)
251                            .unwrap_or_else(|_| json!({"type": "object"}))
252                    })
253                }).collect();
254
255                // Add debug tool if enabled
256                if self.config.include_debug_tool {
257                    tools.push(json!({
258                        "name": "allframe/debug",
259                        "description": "Get AllFrame MCP server diagnostics and status information",
260                        "inputSchema": {
261                            "type": "object",
262                            "properties": {},
263                            "additionalProperties": false
264                        }
265                    }));
266                }
267
268                json!({ "tools": tools })
269            }
270
271            // Call a tool
272            "tools/call" => {
273                let params = &request["params"];
274                let name = params["name"].as_str().unwrap_or("");
275                let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
276
277                self.log_info(&format!("Calling tool: {}", name));
278
279                // Handle built-in debug tool
280                if name == "allframe/debug" && self.config.include_debug_tool {
281                    let diagnostics = self.get_diagnostics();
282                    return Some(json!({
283                        "jsonrpc": "2.0",
284                        "result": {
285                            "content": [{
286                                "type": "text",
287                                "text": serde_json::to_string_pretty(&diagnostics).unwrap()
288                            }]
289                        },
290                        "id": id
291                    }));
292                }
293
294                match self.mcp.call_tool(name, arguments).await {
295                    Ok(result) => {
296                        json!({
297                            "content": [{
298                                "type": "text",
299                                "text": result.to_string()
300                            }]
301                        })
302                    }
303                    Err(e) => {
304                        self.log_error(&format!("Tool error: {}", e));
305                        json!({
306                            "isError": true,
307                            "content": [{
308                                "type": "text",
309                                "text": format!("Error: {}", e)
310                            }]
311                        })
312                    }
313                }
314            }
315
316            // Ping
317            "ping" => {
318                json!({})
319            }
320
321            // Unknown method
322            _ => {
323                self.log_warn(&format!("Unknown method: {}", method));
324                return Some(json!({
325                    "jsonrpc": "2.0",
326                    "error": {
327                        "code": -32601,
328                        "message": format!("Method not found: {}", method)
329                    },
330                    "id": id
331                }));
332            }
333        };
334
335        // Return successful response
336        Some(json!({
337            "jsonrpc": "2.0",
338            "result": result,
339            "id": id
340        }))
341    }
342
343    fn get_diagnostics(&self) -> Value {
344        json!({
345            "server": {
346                "name": self.config.server_name,
347                "version": self.config.server_version,
348                "protocol_version": self.config.protocol_version
349            },
350            "runtime": {
351                "uptime_seconds": self.start_time.elapsed().as_secs(),
352                "request_count": self.request_count.load(Ordering::SeqCst),
353                "tool_count": self.mcp.tool_count(),
354                "pid": std::process::id()
355            },
356            "build": {
357                "pkg_version": env!("CARGO_PKG_VERSION"),
358                "debug_tool_enabled": self.config.include_debug_tool
359            }
360        })
361    }
362
363    // Logging methods that work with or without tracing feature
364
365    fn log_startup(&self) {
366        let msg = format!(
367            "MCP Server starting: name={}, version={}, pid={}, tools={}",
368            self.config.server_name,
369            self.config.server_version,
370            std::process::id(),
371            self.mcp.tool_count()
372        );
373
374        #[cfg(feature = "tracing")]
375        tracing::info!("{}", msg);
376
377        #[cfg(not(feature = "tracing"))]
378        eprintln!("[INFO] {}", msg);
379    }
380
381    fn log_shutdown(&self) {
382        let msg = format!(
383            "MCP Server shutting down: uptime={}s, requests={}",
384            self.start_time.elapsed().as_secs(),
385            self.request_count.load(Ordering::SeqCst)
386        );
387
388        #[cfg(feature = "tracing")]
389        tracing::info!("{}", msg);
390
391        #[cfg(not(feature = "tracing"))]
392        eprintln!("[INFO] {}", msg);
393    }
394
395    fn log_request(&self, id: u64, content: &str) {
396        // Truncate long requests for logging
397        let truncated = if content.len() > 500 {
398            format!("{}...(truncated)", &content[..500])
399        } else {
400            content.to_string()
401        };
402
403        #[cfg(feature = "tracing")]
404        tracing::debug!(request_id = id, request = %truncated, "Received MCP request");
405
406        #[cfg(not(feature = "tracing"))]
407        if std::env::var("ALLFRAME_MCP_DEBUG").is_ok() {
408            eprintln!("[DEBUG] req#{}: {}", id, truncated);
409        }
410    }
411
412    fn log_response(&self, id: u64, content: &str) {
413        let truncated = if content.len() > 500 {
414            format!("{}...(truncated)", &content[..500])
415        } else {
416            content.to_string()
417        };
418
419        #[cfg(feature = "tracing")]
420        tracing::debug!(request_id = id, response = %truncated, "Sending MCP response");
421
422        #[cfg(not(feature = "tracing"))]
423        if std::env::var("ALLFRAME_MCP_DEBUG").is_ok() {
424            eprintln!("[DEBUG] res#{}: {}", id, truncated);
425        }
426    }
427
428    fn log_info(&self, msg: &str) {
429        #[cfg(feature = "tracing")]
430        tracing::info!("{}", msg);
431
432        #[cfg(not(feature = "tracing"))]
433        eprintln!("[INFO] {}", msg);
434    }
435
436    fn log_warn(&self, msg: &str) {
437        #[cfg(feature = "tracing")]
438        tracing::warn!("{}", msg);
439
440        #[cfg(not(feature = "tracing"))]
441        eprintln!("[WARN] {}", msg);
442    }
443
444    fn log_error(&self, msg: &str) {
445        #[cfg(feature = "tracing")]
446        tracing::error!("{}", msg);
447
448        #[cfg(not(feature = "tracing"))]
449        eprintln!("[ERROR] {}", msg);
450    }
451}
452
/// Install a global `tracing` subscriber for the MCP server.
///
/// - The filter comes from `RUST_LOG` (via `EnvFilter`), defaulting to `info`.
/// - If `ALLFRAME_MCP_LOG_FILE` is set, logs go to that file, which is created or
///   truncated. Otherwise they go to stderr.
/// - If the file cannot be created, a warning is printed and logging falls back to stderr.
/// - If a global subscriber is already installed, this call has no effect (it does not
///   panic), so it is safe to call more than once.
#[cfg(feature = "tracing")]
pub fn init_tracing() {
    use tracing_subscriber::EnvFilter;

    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));

    let log_file = std::env::var("ALLFRAME_MCP_LOG_FILE")
        .ok()
        .and_then(|path| match std::fs::File::create(&path) {
            Ok(file) => Some(file),
            Err(e) => {
                // stdout carries the protocol, so the warning must go to stderr.
                eprintln!(
                    "[WARN] Failed to create log file {}: {}; logging to stderr instead",
                    path, e
                );
                None
            }
        });

    let result = match log_file {
        Some(file) => tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_writer(file)
            .with_ansi(false)
            .try_init(),
        None => tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_writer(std::io::stderr)
            .with_ansi(false)
            .try_init(),
    };

    // An error only means a global subscriber already exists; it keeps receiving events.
    let _ = result;
}

/// Initialize tracing.
///
/// This is a no-op because the `tracing` feature is disabled. In this configuration the
/// transport prints its log lines directly to stderr.
#[cfg(not(feature = "tracing"))]
pub fn init_tracing() {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = StdioConfig::default();
        assert_eq!(config.server_name, "allframe-mcp");
        assert_eq!(config.server_version, env!("CARGO_PKG_VERSION"));
        assert_eq!(config.protocol_version, "2024-11-05");
        assert!(!config.include_debug_tool);
    }

    #[test]
    fn test_config_builder() {
        let config = StdioConfig::default()
            .with_debug_tool(true)
            .with_server_name("my-server")
            .with_log_file("/tmp/mcp.log");

        assert!(config.include_debug_tool);
        assert_eq!(config.server_name, "my-server");
        assert_eq!(config.log_file, Some("/tmp/mcp.log".to_string()));
    }

    #[test]
    fn test_config_builder_keeps_untouched_fields() {
        let config = StdioConfig::default().with_debug_tool(true).with_debug_tool(false);

        assert!(!config.include_debug_tool);
        assert_eq!(config.server_name, "allframe-mcp");
        assert_eq!(config.server_version, env!("CARGO_PKG_VERSION"));
        assert_eq!(config.protocol_version, "2024-11-05");
    }
}