codex_memory/mcp_server/
mod.rs

1//! Minimal MCP server implementation
2pub mod handlers;
3pub mod tools;
4pub mod transport;
5
6// Re-export for tests
7pub use handlers::MCPHandlers;
8
9use crate::config::Config;
10use crate::error::Result;
11use crate::storage::Storage;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tracing::{error, info, warn};
16
17/// Simple MCP server
18pub struct MCPServer {
19    _config: Config,
20    handlers: Arc<MCPHandlers>,
21    start_time: Instant,
22    last_request: Arc<std::sync::Mutex<Instant>>,
23}
24
25impl MCPServer {
26    /// Create a new MCP server
27    pub fn new(config: Config, storage: Arc<Storage>) -> Self {
28        let handlers = Arc::new(MCPHandlers::new(storage));
29        let now = Instant::now();
30        Self {
31            _config: config,
32            handlers,
33            start_time: now,
34            last_request: Arc::new(std::sync::Mutex::new(now)),
35        }
36    }
37
38    /// Check if server should self-terminate due to inactivity
39    fn should_terminate(&self) -> bool {
40        let last_request = *self.last_request.lock().unwrap();
41        let inactive_duration = last_request.elapsed();
42        
43        // Terminate if inactive for more than 10 minutes
44        if inactive_duration > Duration::from_secs(600) {
45            warn!("Server inactive for {:?}, initiating shutdown", inactive_duration);
46            return true;
47        }
48        
49        false
50    }
51
52    /// Update last request time
53    fn update_last_request(&self) {
54        *self.last_request.lock().unwrap() = Instant::now();
55    }
56
57    /// Log health status periodically
58    async fn health_monitor(&self) {
59        let mut interval = tokio::time::interval(Duration::from_secs(60)); // Every minute
60        
61        loop {
62            interval.tick().await;
63            
64            if self.should_terminate() {
65                error!("Health monitor detected inactivity timeout, terminating process");
66                std::process::exit(1);
67            }
68            
69            let uptime = self.start_time.elapsed();
70            let last_request_ago = self.last_request.lock().unwrap().elapsed();
71            
72            info!("Health check: uptime={:?}, last_request={:?} ago", uptime, last_request_ago);
73        }
74    }
75
76    /// Run in stdio mode for Claude Desktop
77    pub async fn run_stdio(&self) -> Result<()> {
78        info!("MCP server running in stdio mode");
79
80        // Spawn health monitor task
81        let health_monitor = {
82            let server_clone = Self {
83                _config: self._config.clone(),
84                handlers: Arc::clone(&self.handlers),
85                start_time: self.start_time,
86                last_request: Arc::clone(&self.last_request),
87            };
88            tokio::spawn(async move {
89                server_clone.health_monitor().await;
90            })
91        };
92
93        let stdin = tokio::io::stdin();
94        let stdout = tokio::io::stdout();
95        let mut stdin = stdin;
96        let mut stdout = stdout;
97
98        let mut buffer = String::new();
99        let mut temp_buf = [0u8; 8192]; // 8KB buffer for reading chunks
100
101        loop {
102            // Read chunk from stdin
103            match stdin.read(&mut temp_buf).await {
104                Ok(0) => {
105                    info!("Received EOF, shutting down MCP server");
106                    break; // EOF
107                }
108                Ok(n) => {
109                    // Convert bytes to string and append to buffer
110                    let chunk = String::from_utf8_lossy(&temp_buf[..n]);
111                    info!(
112                        "Read {} bytes: {:?}",
113                        n,
114                        &chunk[..std::cmp::min(100, chunk.len())]
115                    );
116                    buffer.push_str(&chunk);
117
118                    // Process complete JSON objects from buffer
119                    while let Some(json_end) = self.find_complete_json(&buffer) {
120                        let json_str = buffer[..json_end].trim().to_string();
121                        buffer.drain(..json_end);
122
123                        if !json_str.is_empty() {
124                            info!("Processing JSON request ({} chars)", json_str.len());
125                            self.update_last_request();
126                            let response = self.handle_request(&json_str).await;
127                            stdout.write_all(response.as_bytes()).await?;
128                            stdout.write_all(b"\n").await?;
129                            stdout.flush().await?;
130                        }
131                    }
132                }
133                Err(e) => {
134                    error!("Error reading input: {}", e);
135                    break;
136                }
137            }
138        }
139
140        info!("MCP server shutting down gracefully");
141        
142        // Cancel health monitor task
143        health_monitor.abort();
144        
145        // The pool will be dropped automatically when Storage is dropped
146        Ok(())
147    }
148
149    /// Find the end of a complete JSON object in the buffer
150    fn find_complete_json(&self, buffer: &str) -> Option<usize> {
151        let mut brace_count = 0;
152        let mut in_string = false;
153        let mut escape_next = false;
154        let mut start_found = false;
155
156        for (i, ch) in buffer.char_indices() {
157            if escape_next {
158                escape_next = false;
159                continue;
160            }
161
162            match ch {
163                '\\' if in_string => escape_next = true,
164                '"' => in_string = !in_string,
165                '{' if !in_string => {
166                    brace_count += 1;
167                    start_found = true;
168                }
169                '}' if !in_string => {
170                    brace_count -= 1;
171                    if start_found && brace_count == 0 {
172                        return Some(i + 1);
173                    }
174                }
175                _ => {}
176            }
177        }
178
179        None
180    }
181
182    async fn handle_request(&self, request: &str) -> String {
183        let request: serde_json::Value = match serde_json::from_str(request) {
184            Ok(v) => v,
185            Err(e) => {
186                return format!(
187                    r#"{{"jsonrpc":"2.0","id":null,"error":{{"code":-32700,"message":"Parse error: {}"}}}}"#,
188                    e
189                );
190            }
191        };
192
193        let method = request["method"].as_str().unwrap_or("");
194        let params = request.get("params").cloned().unwrap_or_default();
195        let id = request.get("id").cloned();
196
197        let result = match method {
198            "initialize" => Ok(serde_json::json!({
199                "protocolVersion": "2024-11-05",
200                "capabilities": {
201                    "tools": {}
202                },
203                "serverInfo": {
204                    "name": "codex-memory",
205                    "version": env!("CARGO_PKG_VERSION")
206                }
207            })),
208            "tools/list" => Ok(serde_json::json!({
209                "tools": tools::MCPTools::get_tools_list()
210            })),
211            "tools/call" => {
212                let tool_name = params["name"].as_str().unwrap_or("");
213                let tool_params = params.get("arguments").cloned().unwrap_or_default();
214                self.handlers.handle_tool_call(tool_name, tool_params).await
215            }
216            _ => Err(crate::error::Error::Other(format!(
217                "Unknown method: {}",
218                method
219            ))),
220        };
221
222        match result {
223            Ok(value) => {
224                if let Some(id) = id {
225                    format!(r#"{{"jsonrpc":"2.0","id":{},"result":{}}}"#, id, value)
226                } else {
227                    format!(r#"{{"jsonrpc":"2.0","result":{}}}"#, value)
228                }
229            }
230            Err(e) => {
231                if let Some(id) = id {
232                    format!(
233                        r#"{{"jsonrpc":"2.0","id":{},"error":{{"code":-32000,"message":"{}"}}}}"#,
234                        id, e
235                    )
236                } else {
237                    format!(
238                        r#"{{"jsonrpc":"2.0","id":null,"error":{{"code":-32000,"message":"{}"}}}}"#,
239                        e
240                    )
241                }
242            }
243        }
244    }
245}