// codex-memory 3.0.15
//
// A simple memory storage service with an MCP interface for Claude Desktop.
//! Minimal MCP server implementation
pub mod handlers;
pub mod tools;
pub mod transport;

// Re-export for tests
pub use handlers::MCPHandlers;

use crate::config::Config;
use crate::error::Result;
use crate::storage::Storage;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::AsyncWriteExt;
use tokio_util::codec::{Decoder, FramedRead};
use futures_util::StreamExt;
use tracing::{error, info, warn};

/// Simple MCP server speaking JSON-RPC over stdio.
///
/// The handlers and the last-request clock live behind `Arc`, so a second
/// `MCPServer` value built from the same parts observes the same activity.
pub struct MCPServer {
    /// Moment the server was constructed; used to report uptime.
    start_time: Instant,
    /// Time of the most recent request, shared between server values.
    last_request: Arc<std::sync::Mutex<Instant>>,
    /// Tool-call dispatcher built from the storage passed to `new`.
    handlers: Arc<MCPHandlers>,
    /// Server configuration; retained but not read by this module.
    _config: Config,
}

impl MCPServer {
    /// Create a new MCP server
    pub fn new(config: Config, storage: Arc<Storage>) -> Self {
        let handlers = Arc::new(MCPHandlers::new(storage));
        let now = Instant::now();
        Self {
            _config: config,
            handlers,
            start_time: now,
            last_request: Arc::new(std::sync::Mutex::new(now)),
        }
    }

    /// Check if server should self-terminate due to inactivity
    fn should_terminate(&self) -> bool {
        let last_request = *self.last_request.lock().unwrap();
        let inactive_duration = last_request.elapsed();

        // Terminate if inactive for more than 24 hours (Claude Desktop manages restarts)
        if inactive_duration > Duration::from_secs(86400) {
            warn!(
                "Server inactive for {:?}, initiating shutdown",
                inactive_duration
            );
            return true;
        }

        false
    }

    /// Update last request time
    fn update_last_request(&self) {
        *self.last_request.lock().unwrap() = Instant::now();
    }

    /// Log health status periodically
    async fn health_monitor(&self) {
        let mut interval = tokio::time::interval(Duration::from_secs(60)); // Every minute

        loop {
            interval.tick().await;

            if self.should_terminate() {
                error!("Health monitor detected inactivity timeout, terminating process");
                std::process::exit(1);
            }

            let uptime = self.start_time.elapsed();
            let last_request_ago = self.last_request.lock().unwrap().elapsed();

            info!(
                "Health check: uptime={:?}, last_request={:?} ago",
                uptime, last_request_ago
            );
        }
    }

    /// Run in stdio mode for Claude Desktop using secure JSON streaming
    pub async fn run_stdio(&self) -> Result<()> {
        info!("MCP server running in stdio mode with secure JSON streaming");

        // Spawn health monitor task
        let health_monitor = {
            let server_clone = Self {
                _config: self._config.clone(),
                handlers: Arc::clone(&self.handlers),
                start_time: self.start_time,
                last_request: Arc::clone(&self.last_request),
            };
            tokio::spawn(async move {
                server_clone.health_monitor().await;
            })
        };

        let stdin = tokio::io::stdin();
        let stdout = tokio::io::stdout();
        let mut stdout = stdout;

        // Use secure streaming JSON decoder with buffer limits
        let mut framed = FramedRead::new(stdin, SecureJsonDecoder::new());

        loop {
            tokio::select! {
                // Process incoming JSON with timeout protection
                message = framed.next() => {
                    match message {
                        Some(Ok(json_str)) => {
                            info!("Processing JSON request ({} chars)", json_str.len());
                            self.update_last_request();
                            let response = self.handle_request(&json_str).await;
                            if !response.is_empty() {
                                stdout.write_all(response.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                            }
                        }
                        Some(Err(e)) => {
                            // Don't treat "bytes remaining on stream" as an error - it's normal EOF
                            let error_msg = e.to_string();
                            if !error_msg.contains("bytes remaining on stream") {
                                error!("JSON decode error: {}", e);
                                // Send error response back to client
                                let parse_error = crate::error::Error::ParseError(e.to_string());
                                let error_response = parse_error.to_json_rpc_error(None);
                                stdout.write_all(serde_json::to_string(&error_response).unwrap().as_bytes()).await?;
                                stdout.flush().await?;
                            }
                        }
                        None => {
                            info!("Received EOF, shutting down MCP server");
                            break;
                        }
                    }
                }
                // CODEX-MCP-004: Timeout protection for requests (default 60s from Architecture spec)
                _ = tokio::time::sleep(Duration::from_secs(60)) => {
                    if self.should_terminate() {
                        warn!("MCP server inactive for too long, initiating graceful shutdown");
                        break;
                    }
                }
            }
        }

        info!("MCP server shutting down gracefully");

        // Cancel health monitor task
        health_monitor.abort();

        Ok(())
    }

    // SECURITY: Removed vulnerable find_complete_json() function
    // Replaced with secure serde_json streaming in SecureJsonDecoder

    pub async fn handle_request(&self, request: &str) -> String {
        // Add detailed logging for debugging JSON parsing issues
        info!("Raw request to parse: {:?}", request);

        let request: serde_json::Value = match serde_json::from_str(request) {
            Ok(v) => v,
            Err(e) => {
                error!("JSON parse error: {} - Request: {:?}", e, request);
                let parse_error = crate::error::Error::ParseError(e.to_string());
                return serde_json::to_string(&parse_error.to_json_rpc_error(Some(serde_json::json!(0))))
                    .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32700,"message":"Parse error"}}"#.to_string());
            }
        };

        // CODEX-MCP-002: Validate JSON-RPC request structure
        let method = request["method"].as_str().unwrap_or("");
        if method.is_empty() {
            let invalid_request_error = crate::error::Error::InvalidRequest("Missing 'method' field".to_string());
            return serde_json::to_string(&invalid_request_error.to_json_rpc_error(request.get("id").cloned()))
                .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid Request"}}"#.to_string());
        }

        let params = request.get("params").cloned().unwrap_or_default();
        let id = request.get("id").cloned();

        let result = match method {
            "initialize" => Ok(serde_json::json!({
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {}
                },
                "serverInfo": {
                    "name": "codex-memory",
                    "version": env!("CARGO_PKG_VERSION")
                }
            })),
            "tools/list" => Ok(serde_json::json!({
                "tools": tools::MCPTools::get_tools_list()
            })),
            "tools/call" => {
                let tool_name = params["name"].as_str().unwrap_or("");
                let tool_params = params.get("arguments").cloned().unwrap_or_default();
                
                // CODEX-MCP-004: Add timeout handling for tool calls (default 60s from Architecture spec)
                let timeout_duration = std::time::Duration::from_secs(60);
                
                match tokio::time::timeout(timeout_duration, 
                    self.handlers.handle_tool_call(tool_name, tool_params)
                ).await {
                    Ok(result) => result,
                    Err(_) => Err(crate::error::Error::Timeout(format!(
                        "Tool call '{}' timed out after {} seconds", 
                        tool_name, 
                        timeout_duration.as_secs()
                    )))
                }
            }
            "prompts/list" => {
                // Return empty prompts list (we don't support prompts)
                Ok(serde_json::json!({
                    "prompts": []
                }))
            }
            "resources/list" => {
                // Return empty resources list (we don't support resources)
                Ok(serde_json::json!({
                    "resources": []
                }))
            }
            "notifications/initialized" => {
                // Notifications don't require responses, just acknowledge silently
                return "".to_string(); // Return empty string for notifications
            }
            _ => {
                // CODEX-MCP-002: Use proper JSON-RPC error code for unknown methods
                Err(crate::error::Error::MethodNotFound(format!(
                    "Unknown method: {}. Supported methods: initialize, tools/list, tools/call, prompts/list, resources/list, notifications/initialized",
                    method
                )))
            }
        };

        match result {
            Ok(value) => {
                if let Some(id) = id {
                    format!(r#"{{"jsonrpc":"2.0","id":{},"result":{}}}"#, id, value)
                } else {
                    format!(r#"{{"jsonrpc":"2.0","result":{}}}"#, value)
                }
            }
            Err(e) => {
                // Log the error for debugging connection failures
                error!("MCP request failed - Method: {}, Error: {}", method, e);

                // CODEX-MCP-002: Use JSON-RPC 2.0 compliant error responses with proper error codes
                let error_response = e.to_json_rpc_error(id.or_else(|| Some(serde_json::json!(0))));
                serde_json::to_string(&error_response)
                    .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32603,"message":"Internal error"}}"#.to_string())
            }
        }
    }
}

/// Secure JSON decoder with buffer size limits and attack protection
struct SecureJsonDecoder {
    /// Maximum buffer size to prevent memory exhaustion attacks (10MB)
    max_buffer_size: usize,
}

impl SecureJsonDecoder {
    fn new() -> Self {
        Self {
            max_buffer_size: 10 * 1024 * 1024, // 10MB limit per Architecture spec
        }
    }
}

impl Decoder for SecureJsonDecoder {
    type Item = String;
    type Error = std::io::Error;

    fn decode(&mut self, src: &mut bytes::BytesMut) -> std::result::Result<Option<Self::Item>, Self::Error> {
        // SECURITY: Enforce buffer size limits to prevent memory exhaustion attacks
        if src.len() > self.max_buffer_size {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "Buffer size limit exceeded: {} bytes (max: {})",
                    src.len(),
                    self.max_buffer_size
                ),
            ));
        }

        // Convert buffer to string with strict UTF-8 validation (replaces lossy conversion)
        match std::str::from_utf8(src) {
            Ok(_) => {}, // Valid UTF-8, continue processing
            Err(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "Invalid UTF-8 encoding in JSON stream",
                ));
            }
        };

        // SECURITY: Use secure serde_json streaming parser instead of custom parser
        let mut depth = 0;
        let mut in_string = false;
        let mut escape_next = false;
        let mut json_start = None;
        
        for (i, byte) in src.iter().enumerate() {
            let ch = *byte as char;
            
            if escape_next {
                escape_next = false;
                continue;
            }

            match ch {
                '\\' if in_string => escape_next = true,
                '"' => in_string = !in_string,
                '{' if !in_string => {
                    if json_start.is_none() {
                        json_start = Some(i);
                    }
                    depth += 1;
                    // SECURITY: Limit recursion depth to prevent stack overflow attacks
                    if depth > 100 {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "JSON nesting depth exceeded (max: 100 levels)",
                        ));
                    }
                }
                '}' if !in_string => {
                    depth -= 1;
                    if depth == 0 && json_start.is_some() {
                        // Found complete JSON object - efficient zero-copy extraction
                        let json_bytes = src.split_to(i + 1);
                        
                        // PERFORMANCE: Use strict UTF-8 validation without lossy conversion (CODEX-MCP-011)
                        let json_str = match std::str::from_utf8(&json_bytes) {
                            Ok(s) => s.to_string(),
                            Err(e) => {
                                return Err(std::io::Error::new(
                                    std::io::ErrorKind::InvalidData,
                                    format!("Invalid UTF-8 in JSON: {}", e),
                                ));
                            }
                        };
                        
                        // SECURITY: Validate JSON using serde_json before processing
                        match serde_json::from_str::<serde_json::Value>(&json_str) {
                            Ok(_) => return Ok(Some(json_str)),
                            Err(e) => {
                                return Err(std::io::Error::new(
                                    std::io::ErrorKind::InvalidData,
                                    format!("Invalid JSON structure: {}", e),
                                ));
                            }
                        }
                    }
                }
                _ => {}
            }
        }

        // No complete JSON object found yet
        Ok(None)
    }
}