codex_memory/mcp_server/
mod.rs

1//! Minimal MCP server implementation
2pub mod handlers;
3pub mod tools;
4pub mod transport;
5
6// Re-export for tests
7pub use handlers::MCPHandlers;
8
9use crate::config::Config;
10use crate::error::Result;
11use crate::storage::Storage;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::io::AsyncWriteExt;
15use tokio_util::codec::{Decoder, FramedRead};
16use futures_util::StreamExt;
17use tracing::{error, info, warn};
18
19/// Simple MCP server
20pub struct MCPServer {
21    _config: Config,
22    handlers: Arc<MCPHandlers>,
23    start_time: Instant,
24    last_request: Arc<std::sync::Mutex<Instant>>,
25}
26
27impl MCPServer {
28    /// Create a new MCP server
29    pub fn new(config: Config, storage: Arc<Storage>) -> Self {
30        let handlers = Arc::new(MCPHandlers::new(storage));
31        let now = Instant::now();
32        Self {
33            _config: config,
34            handlers,
35            start_time: now,
36            last_request: Arc::new(std::sync::Mutex::new(now)),
37        }
38    }
39
40    /// Check if server should self-terminate due to inactivity
41    fn should_terminate(&self) -> bool {
42        let last_request = *self.last_request.lock().unwrap();
43        let inactive_duration = last_request.elapsed();
44
45        // Terminate if inactive for more than 24 hours (Claude Desktop manages restarts)
46        if inactive_duration > Duration::from_secs(86400) {
47            warn!(
48                "Server inactive for {:?}, initiating shutdown",
49                inactive_duration
50            );
51            return true;
52        }
53
54        false
55    }
56
57    /// Update last request time
58    fn update_last_request(&self) {
59        *self.last_request.lock().unwrap() = Instant::now();
60    }
61
62    /// Log health status periodically
63    async fn health_monitor(&self) {
64        let mut interval = tokio::time::interval(Duration::from_secs(60)); // Every minute
65
66        loop {
67            interval.tick().await;
68
69            if self.should_terminate() {
70                error!("Health monitor detected inactivity timeout, terminating process");
71                std::process::exit(1);
72            }
73
74            let uptime = self.start_time.elapsed();
75            let last_request_ago = self.last_request.lock().unwrap().elapsed();
76
77            info!(
78                "Health check: uptime={:?}, last_request={:?} ago",
79                uptime, last_request_ago
80            );
81        }
82    }
83
84    /// Run in stdio mode for Claude Desktop using secure JSON streaming
85    pub async fn run_stdio(&self) -> Result<()> {
86        info!("MCP server running in stdio mode with secure JSON streaming");
87
88        // Spawn health monitor task
89        let health_monitor = {
90            let server_clone = Self {
91                _config: self._config.clone(),
92                handlers: Arc::clone(&self.handlers),
93                start_time: self.start_time,
94                last_request: Arc::clone(&self.last_request),
95            };
96            tokio::spawn(async move {
97                server_clone.health_monitor().await;
98            })
99        };
100
101        let stdin = tokio::io::stdin();
102        let stdout = tokio::io::stdout();
103        let mut stdout = stdout;
104
105        // Use secure streaming JSON decoder with buffer limits
106        let mut framed = FramedRead::new(stdin, SecureJsonDecoder::new());
107
108        loop {
109            tokio::select! {
110                // Process incoming JSON with timeout protection
111                message = framed.next() => {
112                    match message {
113                        Some(Ok(json_str)) => {
114                            info!("Processing JSON request ({} chars)", json_str.len());
115                            self.update_last_request();
116                            let response = self.handle_request(&json_str).await;
117                            if !response.is_empty() {
118                                stdout.write_all(response.as_bytes()).await?;
119                                stdout.write_all(b"\n").await?;
120                                stdout.flush().await?;
121                            }
122                        }
123                        Some(Err(e)) => {
124                            // Don't treat "bytes remaining on stream" as an error - it's normal EOF
125                            let error_msg = e.to_string();
126                            if !error_msg.contains("bytes remaining on stream") {
127                                error!("JSON decode error: {}", e);
128                                // Send error response back to client
129                                let parse_error = crate::error::Error::ParseError(e.to_string());
130                                let error_response = parse_error.to_json_rpc_error(None);
131                                stdout.write_all(serde_json::to_string(&error_response).unwrap().as_bytes()).await?;
132                                stdout.flush().await?;
133                            }
134                        }
135                        None => {
136                            info!("Received EOF, shutting down MCP server");
137                            break;
138                        }
139                    }
140                }
141                // CODEX-MCP-004: Timeout protection for requests (default 60s from Architecture spec)
142                _ = tokio::time::sleep(Duration::from_secs(60)) => {
143                    if self.should_terminate() {
144                        warn!("MCP server inactive for too long, initiating graceful shutdown");
145                        break;
146                    }
147                }
148            }
149        }
150
151        info!("MCP server shutting down gracefully");
152
153        // Cancel health monitor task
154        health_monitor.abort();
155
156        Ok(())
157    }
158
159    // SECURITY: Removed vulnerable find_complete_json() function
160    // Replaced with secure serde_json streaming in SecureJsonDecoder
161
162    pub async fn handle_request(&self, request: &str) -> String {
163        // Add detailed logging for debugging JSON parsing issues
164        info!("Raw request to parse: {:?}", request);
165
166        let request: serde_json::Value = match serde_json::from_str(request) {
167            Ok(v) => v,
168            Err(e) => {
169                error!("JSON parse error: {} - Request: {:?}", e, request);
170                let parse_error = crate::error::Error::ParseError(e.to_string());
171                return serde_json::to_string(&parse_error.to_json_rpc_error(Some(serde_json::json!(0))))
172                    .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32700,"message":"Parse error"}}"#.to_string());
173            }
174        };
175
176        // CODEX-MCP-002: Validate JSON-RPC request structure
177        let method = request["method"].as_str().unwrap_or("");
178        if method.is_empty() {
179            let invalid_request_error = crate::error::Error::InvalidRequest("Missing 'method' field".to_string());
180            return serde_json::to_string(&invalid_request_error.to_json_rpc_error(request.get("id").cloned()))
181                .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid Request"}}"#.to_string());
182        }
183
184        let params = request.get("params").cloned().unwrap_or_default();
185        let id = request.get("id").cloned();
186
187        let result = match method {
188            "initialize" => Ok(serde_json::json!({
189                "protocolVersion": "2024-11-05",
190                "capabilities": {
191                    "tools": {}
192                },
193                "serverInfo": {
194                    "name": "codex-memory",
195                    "version": env!("CARGO_PKG_VERSION")
196                }
197            })),
198            "tools/list" => Ok(serde_json::json!({
199                "tools": tools::MCPTools::get_tools_list()
200            })),
201            "tools/call" => {
202                let tool_name = params["name"].as_str().unwrap_or("");
203                let tool_params = params.get("arguments").cloned().unwrap_or_default();
204                
205                // CODEX-MCP-004: Add timeout handling for tool calls (default 60s from Architecture spec)
206                let timeout_duration = std::time::Duration::from_secs(60);
207                
208                match tokio::time::timeout(timeout_duration, 
209                    self.handlers.handle_tool_call(tool_name, tool_params)
210                ).await {
211                    Ok(result) => result,
212                    Err(_) => Err(crate::error::Error::Timeout(format!(
213                        "Tool call '{}' timed out after {} seconds", 
214                        tool_name, 
215                        timeout_duration.as_secs()
216                    )))
217                }
218            }
219            "prompts/list" => {
220                // Return empty prompts list (we don't support prompts)
221                Ok(serde_json::json!({
222                    "prompts": []
223                }))
224            }
225            "resources/list" => {
226                // Return empty resources list (we don't support resources)
227                Ok(serde_json::json!({
228                    "resources": []
229                }))
230            }
231            "notifications/initialized" => {
232                // Notifications don't require responses, just acknowledge silently
233                return "".to_string(); // Return empty string for notifications
234            }
235            _ => {
236                // CODEX-MCP-002: Use proper JSON-RPC error code for unknown methods
237                Err(crate::error::Error::MethodNotFound(format!(
238                    "Unknown method: {}. Supported methods: initialize, tools/list, tools/call, prompts/list, resources/list, notifications/initialized",
239                    method
240                )))
241            }
242        };
243
244        match result {
245            Ok(value) => {
246                if let Some(id) = id {
247                    format!(r#"{{"jsonrpc":"2.0","id":{},"result":{}}}"#, id, value)
248                } else {
249                    format!(r#"{{"jsonrpc":"2.0","result":{}}}"#, value)
250                }
251            }
252            Err(e) => {
253                // Log the error for debugging connection failures
254                error!("MCP request failed - Method: {}, Error: {}", method, e);
255
256                // CODEX-MCP-002: Use JSON-RPC 2.0 compliant error responses with proper error codes
257                let error_response = e.to_json_rpc_error(id.or_else(|| Some(serde_json::json!(0))));
258                serde_json::to_string(&error_response)
259                    .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32603,"message":"Internal error"}}"#.to_string())
260            }
261        }
262    }
263}
264
/// JSON stream decoder state: holds the cap on how many bytes may be
/// buffered while waiting for a complete message.
struct SecureJsonDecoder {
    /// Upper bound on the input buffer, in bytes, guarding against memory exhaustion.
    max_buffer_size: usize,
}

impl SecureJsonDecoder {
    /// Buffer cap applied by `new`: 10MB, per the Architecture spec.
    const DEFAULT_MAX_BUFFER_SIZE: usize = 10 * 1024 * 1024;

    /// Builds a decoder that uses the default 10MB buffer cap.
    fn new() -> Self {
        let max_buffer_size = Self::DEFAULT_MAX_BUFFER_SIZE;
        Self { max_buffer_size }
    }
}
278
279impl Decoder for SecureJsonDecoder {
280    type Item = String;
281    type Error = std::io::Error;
282
283    fn decode(&mut self, src: &mut bytes::BytesMut) -> std::result::Result<Option<Self::Item>, Self::Error> {
284        // SECURITY: Enforce buffer size limits to prevent memory exhaustion attacks
285        if src.len() > self.max_buffer_size {
286            return Err(std::io::Error::new(
287                std::io::ErrorKind::InvalidData,
288                format!(
289                    "Buffer size limit exceeded: {} bytes (max: {})",
290                    src.len(),
291                    self.max_buffer_size
292                ),
293            ));
294        }
295
296        // Convert buffer to string with strict UTF-8 validation (replaces lossy conversion)
297        match std::str::from_utf8(src) {
298            Ok(_) => {}, // Valid UTF-8, continue processing
299            Err(_) => {
300                return Err(std::io::Error::new(
301                    std::io::ErrorKind::InvalidData,
302                    "Invalid UTF-8 encoding in JSON stream",
303                ));
304            }
305        };
306
307        // SECURITY: Use secure serde_json streaming parser instead of custom parser
308        let mut depth = 0;
309        let mut in_string = false;
310        let mut escape_next = false;
311        let mut json_start = None;
312        
313        for (i, byte) in src.iter().enumerate() {
314            let ch = *byte as char;
315            
316            if escape_next {
317                escape_next = false;
318                continue;
319            }
320
321            match ch {
322                '\\' if in_string => escape_next = true,
323                '"' => in_string = !in_string,
324                '{' if !in_string => {
325                    if json_start.is_none() {
326                        json_start = Some(i);
327                    }
328                    depth += 1;
329                    // SECURITY: Limit recursion depth to prevent stack overflow attacks
330                    if depth > 100 {
331                        return Err(std::io::Error::new(
332                            std::io::ErrorKind::InvalidData,
333                            "JSON nesting depth exceeded (max: 100 levels)",
334                        ));
335                    }
336                }
337                '}' if !in_string => {
338                    depth -= 1;
339                    if depth == 0 && json_start.is_some() {
340                        // Found complete JSON object - efficient zero-copy extraction
341                        let json_bytes = src.split_to(i + 1);
342                        
343                        // PERFORMANCE: Use strict UTF-8 validation without lossy conversion (CODEX-MCP-011)
344                        let json_str = match std::str::from_utf8(&json_bytes) {
345                            Ok(s) => s.to_string(),
346                            Err(e) => {
347                                return Err(std::io::Error::new(
348                                    std::io::ErrorKind::InvalidData,
349                                    format!("Invalid UTF-8 in JSON: {}", e),
350                                ));
351                            }
352                        };
353                        
354                        // SECURITY: Validate JSON using serde_json before processing
355                        match serde_json::from_str::<serde_json::Value>(&json_str) {
356                            Ok(_) => return Ok(Some(json_str)),
357                            Err(e) => {
358                                return Err(std::io::Error::new(
359                                    std::io::ErrorKind::InvalidData,
360                                    format!("Invalid JSON structure: {}", e),
361                                ));
362                            }
363                        }
364                    }
365                }
366                _ => {}
367            }
368        }
369
370        // No complete JSON object found yet
371        Ok(None)
372    }
373}