1use std::io::{stdin, stdout, BufRead, Write};
30use std::sync::atomic::{AtomicU64, Ordering};
31use std::time::Instant;
32
33use serde_json::{json, Value};
34
35use crate::McpServer;
36
/// Configuration for the stdio MCP transport.
///
/// Carries the identity values reported to clients during `initialize`, the
/// switch for the built-in debug tool, and an optional log file path.
///
/// `PartialEq`/`Eq` are derived so whole configurations can be compared
/// directly (for example in tests) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StdioConfig {
    /// Server name reported to clients in `serverInfo.name`.
    pub server_name: String,
    /// Server version reported to clients in `serverInfo.version`.
    pub server_version: String,
    /// MCP protocol version string returned from `initialize`.
    pub protocol_version: String,
    /// When `true`, the `allframe/debug` tool is listed and callable.
    pub include_debug_tool: bool,
    /// Optional path of a log file.
    ///
    /// NOTE(review): nothing visible in this file reads this field;
    /// `init_tracing` consults `ALLFRAME_MCP_LOG_FILE` directly — confirm the
    /// intended consumer.
    pub log_file: Option<String>,
}
51
52impl Default for StdioConfig {
53 fn default() -> Self {
54 Self {
55 server_name: "allframe-mcp".to_string(),
56 server_version: env!("CARGO_PKG_VERSION").to_string(),
57 protocol_version: "2024-11-05".to_string(),
58 include_debug_tool: false,
59 log_file: std::env::var("ALLFRAME_MCP_LOG_FILE").ok(),
60 }
61 }
62}
63
64impl StdioConfig {
65 pub fn with_debug_tool(mut self, enabled: bool) -> Self {
67 self.include_debug_tool = enabled;
68 self
69 }
70
71 pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
73 self.server_name = name.into();
74 self
75 }
76
77 pub fn with_log_file(mut self, path: impl Into<String>) -> Self {
79 self.log_file = Some(path.into());
80 self
81 }
82}
83
84pub struct StdioTransport {
86 mcp: McpServer,
87 config: StdioConfig,
88 start_time: Instant,
89 request_count: AtomicU64,
90}
91
92impl StdioTransport {
93 pub fn new(mcp: McpServer, config: StdioConfig) -> Self {
95 Self {
96 mcp,
97 config,
98 start_time: Instant::now(),
99 request_count: AtomicU64::new(0),
100 }
101 }
102
103 pub async fn serve(self) {
105 self.log_startup();
106
107 let stdin = stdin();
108 let mut stdout = stdout();
109
110 let shutdown = async {
112 #[cfg(unix)]
113 {
114 use tokio::signal::unix::{signal, SignalKind};
115 let mut sigterm = signal(SignalKind::terminate()).ok();
116 let mut sigint = signal(SignalKind::interrupt()).ok();
117
118 tokio::select! {
119 _ = async { if let Some(ref mut s) = sigterm { s.recv().await } else { std::future::pending().await } } => {
120 self.log_info("Received SIGTERM");
121 }
122 _ = async { if let Some(ref mut s) = sigint { s.recv().await } else { std::future::pending().await } } => {
123 self.log_info("Received SIGINT");
124 }
125 }
126 }
127 #[cfg(not(unix))]
128 {
129 tokio::signal::ctrl_c().await.ok();
130 self.log_info("Received shutdown signal");
131 }
132 };
133
134 tokio::select! {
136 _ = self.run_loop(&stdin, &mut stdout) => {}
137 _ = shutdown => {
138 self.log_info("Shutting down gracefully");
139 }
140 }
141
142 self.log_shutdown();
143 }
144
145 async fn run_loop(&self, stdin: &std::io::Stdin, stdout: &mut std::io::Stdout) {
146 for line in stdin.lock().lines() {
147 let line = match line {
148 Ok(l) => l,
149 Err(e) => {
150 self.log_error(&format!("Error reading line: {}", e));
151 continue;
152 }
153 };
154
155 if line.trim().is_empty() {
157 continue;
158 }
159
160 self.request_count.fetch_add(1, Ordering::SeqCst);
161 let request_id = self.request_count.load(Ordering::SeqCst);
162
163 self.log_request(request_id, &line);
164
165 let request: Value = match serde_json::from_str(&line) {
167 Ok(r) => r,
168 Err(e) => {
169 self.log_error(&format!("Parse error: {}", e));
170 let error = json!({
171 "jsonrpc": "2.0",
172 "error": {
173 "code": -32700,
174 "message": "Parse error"
175 },
176 "id": null
177 });
178 self.write_response(stdout, &error, request_id);
179 continue;
180 }
181 };
182
183 let response = self.handle_request(request).await;
185
186 if let Some(resp) = response {
188 self.write_response(stdout, &resp, request_id);
189 }
190 }
191 }
192
193 fn write_response(&self, stdout: &mut std::io::Stdout, response: &Value, request_id: u64) {
194 match serde_json::to_string(&response) {
195 Ok(json_str) => {
196 self.log_response(request_id, &json_str);
197 if let Err(e) = writeln!(stdout, "{}", json_str) {
198 self.log_error(&format!("Error writing response: {}", e));
199 }
200 if let Err(e) = stdout.flush() {
201 self.log_error(&format!("Error flushing stdout: {}", e));
202 }
203 }
204 Err(e) => {
205 self.log_error(&format!("Error serializing response: {}", e));
206 }
207 }
208 }
209
210 async fn handle_request(&self, request: Value) -> Option<Value> {
211 let method = request["method"].as_str().unwrap_or("");
212 let id = request.get("id").cloned();
213
214 match method {
216 "initialized" | "notifications/initialized" => {
218 self.log_info("Client initialized connection");
219 return None;
220 }
221 "notifications/cancelled" => {
222 self.log_info("Request cancelled by client");
223 return None;
224 }
225 _ => {}
226 }
227
228 let result = match method {
229 "initialize" => {
231 self.log_info("Initializing MCP connection");
232 json!({
233 "protocolVersion": self.config.protocol_version,
234 "capabilities": {
235 "tools": {}
236 },
237 "serverInfo": {
238 "name": self.config.server_name,
239 "version": self.config.server_version
240 }
241 })
242 }
243
244 "tools/list" => {
246 let mut tools: Vec<Value> = self.mcp.list_tools().await.iter().map(|t| {
247 json!({
248 "name": t.name,
249 "description": t.description,
250 "inputSchema": serde_json::from_str::<Value>(&t.input_schema)
251 .unwrap_or_else(|_| json!({"type": "object"}))
252 })
253 }).collect();
254
255 if self.config.include_debug_tool {
257 tools.push(json!({
258 "name": "allframe/debug",
259 "description": "Get AllFrame MCP server diagnostics and status information",
260 "inputSchema": {
261 "type": "object",
262 "properties": {},
263 "additionalProperties": false
264 }
265 }));
266 }
267
268 json!({ "tools": tools })
269 }
270
271 "tools/call" => {
273 let params = &request["params"];
274 let name = params["name"].as_str().unwrap_or("");
275 let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
276
277 self.log_info(&format!("Calling tool: {}", name));
278
279 if name == "allframe/debug" && self.config.include_debug_tool {
281 let diagnostics = self.get_diagnostics();
282 return Some(json!({
283 "jsonrpc": "2.0",
284 "result": {
285 "content": [{
286 "type": "text",
287 "text": serde_json::to_string_pretty(&diagnostics).unwrap()
288 }]
289 },
290 "id": id
291 }));
292 }
293
294 match self.mcp.call_tool(name, arguments).await {
295 Ok(result) => {
296 json!({
297 "content": [{
298 "type": "text",
299 "text": result.to_string()
300 }]
301 })
302 }
303 Err(e) => {
304 self.log_error(&format!("Tool error: {}", e));
305 json!({
306 "isError": true,
307 "content": [{
308 "type": "text",
309 "text": format!("Error: {}", e)
310 }]
311 })
312 }
313 }
314 }
315
316 "ping" => {
318 json!({})
319 }
320
321 _ => {
323 self.log_warn(&format!("Unknown method: {}", method));
324 return Some(json!({
325 "jsonrpc": "2.0",
326 "error": {
327 "code": -32601,
328 "message": format!("Method not found: {}", method)
329 },
330 "id": id
331 }));
332 }
333 };
334
335 Some(json!({
337 "jsonrpc": "2.0",
338 "result": result,
339 "id": id
340 }))
341 }
342
343 fn get_diagnostics(&self) -> Value {
344 json!({
345 "server": {
346 "name": self.config.server_name,
347 "version": self.config.server_version,
348 "protocol_version": self.config.protocol_version
349 },
350 "runtime": {
351 "uptime_seconds": self.start_time.elapsed().as_secs(),
352 "request_count": self.request_count.load(Ordering::SeqCst),
353 "tool_count": self.mcp.tool_count(),
354 "pid": std::process::id()
355 },
356 "build": {
357 "pkg_version": env!("CARGO_PKG_VERSION"),
358 "debug_tool_enabled": self.config.include_debug_tool
359 }
360 })
361 }
362
363 fn log_startup(&self) {
366 let msg = format!(
367 "MCP Server starting: name={}, version={}, pid={}, tools={}",
368 self.config.server_name,
369 self.config.server_version,
370 std::process::id(),
371 self.mcp.tool_count()
372 );
373
374 #[cfg(feature = "tracing")]
375 tracing::info!("{}", msg);
376
377 #[cfg(not(feature = "tracing"))]
378 eprintln!("[INFO] {}", msg);
379 }
380
381 fn log_shutdown(&self) {
382 let msg = format!(
383 "MCP Server shutting down: uptime={}s, requests={}",
384 self.start_time.elapsed().as_secs(),
385 self.request_count.load(Ordering::SeqCst)
386 );
387
388 #[cfg(feature = "tracing")]
389 tracing::info!("{}", msg);
390
391 #[cfg(not(feature = "tracing"))]
392 eprintln!("[INFO] {}", msg);
393 }
394
395 fn log_request(&self, id: u64, content: &str) {
396 let truncated = if content.len() > 500 {
398 format!("{}...(truncated)", &content[..500])
399 } else {
400 content.to_string()
401 };
402
403 #[cfg(feature = "tracing")]
404 tracing::debug!(request_id = id, request = %truncated, "Received MCP request");
405
406 #[cfg(not(feature = "tracing"))]
407 if std::env::var("ALLFRAME_MCP_DEBUG").is_ok() {
408 eprintln!("[DEBUG] req#{}: {}", id, truncated);
409 }
410 }
411
412 fn log_response(&self, id: u64, content: &str) {
413 let truncated = if content.len() > 500 {
414 format!("{}...(truncated)", &content[..500])
415 } else {
416 content.to_string()
417 };
418
419 #[cfg(feature = "tracing")]
420 tracing::debug!(request_id = id, response = %truncated, "Sending MCP response");
421
422 #[cfg(not(feature = "tracing"))]
423 if std::env::var("ALLFRAME_MCP_DEBUG").is_ok() {
424 eprintln!("[DEBUG] res#{}: {}", id, truncated);
425 }
426 }
427
428 fn log_info(&self, msg: &str) {
429 #[cfg(feature = "tracing")]
430 tracing::info!("{}", msg);
431
432 #[cfg(not(feature = "tracing"))]
433 eprintln!("[INFO] {}", msg);
434 }
435
436 fn log_warn(&self, msg: &str) {
437 #[cfg(feature = "tracing")]
438 tracing::warn!("{}", msg);
439
440 #[cfg(not(feature = "tracing"))]
441 eprintln!("[WARN] {}", msg);
442 }
443
444 fn log_error(&self, msg: &str) {
445 #[cfg(feature = "tracing")]
446 tracing::error!("{}", msg);
447
448 #[cfg(not(feature = "tracing"))]
449 eprintln!("[ERROR] {}", msg);
450 }
451}
452
/// Installs the global `tracing` subscriber for the MCP server.
///
/// The filter comes from the default environment filter and falls back to
/// `info`. Output goes to the file named by `ALLFRAME_MCP_LOG_FILE` when
/// that variable is set and the file can be created, otherwise to stderr.
/// ANSI colors are disabled in both cases.
///
/// A log file that cannot be created no longer aborts the process: a
/// warning is printed to stderr and logging falls back to stderr.
///
/// NOTE(review): `init()` presumably panics if a global subscriber is
/// already installed — confirm against the tracing-subscriber docs.
#[cfg(feature = "tracing")]
pub fn init_tracing() {
    use tracing_subscriber::EnvFilter;

    let filter = EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| EnvFilter::new("info"));

    // Try to open the requested log file; on failure report it and fall
    // back to stderr instead of crashing the server at startup.
    let log_file = std::env::var("ALLFRAME_MCP_LOG_FILE")
        .ok()
        .and_then(|path| match std::fs::File::create(&path) {
            Ok(file) => Some(file),
            Err(e) => {
                eprintln!(
                    "[WARN] Failed to create log file {}: {}; logging to stderr instead",
                    path, e
                );
                None
            }
        });

    match log_file {
        Some(file) => {
            tracing_subscriber::fmt()
                .with_env_filter(filter)
                .with_writer(file)
                .with_ansi(false)
                .init();
        }
        None => {
            tracing_subscriber::fmt()
                .with_env_filter(filter)
                .with_writer(std::io::stderr)
                .with_ansi(false)
                .init();
        }
    }
}
480
/// No-op stand-in used when the `tracing` feature is disabled, so callers
/// can invoke `init_tracing()` unconditionally.
#[cfg(not(feature = "tracing"))]
pub fn init_tracing() {}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses the stock server name and leaves the
    /// debug tool disabled.
    #[test]
    fn test_config_default() {
        let StdioConfig {
            server_name,
            include_debug_tool,
            ..
        } = StdioConfig::default();

        assert_eq!(server_name, "allframe-mcp");
        assert!(!include_debug_tool);
    }

    /// Each builder method overrides exactly the field it is named after.
    #[test]
    fn test_config_builder() {
        let base = StdioConfig::default();
        let config = base
            .with_debug_tool(true)
            .with_server_name("my-server")
            .with_log_file("/tmp/mcp.log");

        assert!(config.include_debug_tool);
        assert_eq!(config.server_name, "my-server");
        assert_eq!(config.log_file.as_deref(), Some("/tmp/mcp.log"));
    }
}