codex_memory/mcp_server/
mod.rs

1//! Model Context Protocol (MCP) Server Implementation
2//!
3//! This module provides a complete MCP server implementation that follows
4//! the MCP protocol specification 2025-06-18 using stdio transport.
5//!
6//! The server exposes memory management tools through the MCP protocol,
7//! allowing Claude to store, search, and manage hierarchical memories.
8
9pub mod auth;
10pub mod circuit_breaker;
11pub mod handlers;
12pub mod rate_limiter;
13pub mod tools;
14pub mod transport;
15
16pub use auth::{AuthContext, AuthMethod, MCPAuth, MCPAuthConfig};
17pub use circuit_breaker::{
18    CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
19};
20pub use handlers::MCPHandlers;
21pub use rate_limiter::{MCPRateLimitConfig, MCPRateLimiter, RateLimitStats};
22pub use tools::MCPTools;
23pub use transport::StdioTransport;
24
25use crate::memory::{
26    ImportanceAssessmentConfig, ImportanceAssessmentPipeline, MemoryRepository,
27    SilentHarvesterService,
28};
29use crate::security::{audit::AuditLogger, AuditConfig};
30use crate::SimpleEmbedder;
31use anyhow::Result;
32use std::sync::Arc;
33use tracing::info;
34
35/// MCP Server configuration
36#[derive(Clone, Debug)]
37pub struct MCPServerConfig {
38    pub request_timeout_ms: u64,
39    pub max_request_size: usize,
40    pub enable_circuit_breaker: bool,
41    pub circuit_breaker: CircuitBreakerConfig,
42    pub enable_authentication: bool,
43    pub auth: MCPAuthConfig,
44    pub enable_rate_limiting: bool,
45    pub rate_limiting: MCPRateLimitConfig,
46    pub audit: AuditConfig,
47}
48
49impl Default for MCPServerConfig {
50    fn default() -> Self {
51        Self {
52            request_timeout_ms: 30000,
53            max_request_size: 10 * 1024 * 1024, // 10MB
54            enable_circuit_breaker: std::env::var("MCP_CIRCUIT_BREAKER_ENABLED")
55                .map(|s| s.parse().unwrap_or(true))
56                .unwrap_or(true),
57            circuit_breaker: CircuitBreakerConfig::default(),
58            enable_authentication: std::env::var("MCP_AUTH_ENABLED")
59                .map(|s| s.parse().unwrap_or(true))
60                .unwrap_or(true),
61            auth: MCPAuthConfig::from_env(),
62            enable_rate_limiting: std::env::var("MCP_RATE_LIMIT_ENABLED")
63                .map(|s| s.parse().unwrap_or(true))
64                .unwrap_or(true),
65            rate_limiting: MCPRateLimitConfig::from_env(),
66            audit: AuditConfig::default(),
67        }
68    }
69}
70
71/// Main MCP Server implementation
72pub struct MCPServer {
73    config: MCPServerConfig,
74    repository: Arc<MemoryRepository>,
75    embedder: Arc<SimpleEmbedder>,
76    handlers: MCPHandlers,
77    transport: StdioTransport,
78    circuit_breaker: Option<Arc<CircuitBreaker>>,
79    harvester_service: Arc<SilentHarvesterService>,
80    auth: Option<Arc<MCPAuth>>,
81    rate_limiter: Option<Arc<MCPRateLimiter>>,
82    audit_logger: Arc<AuditLogger>,
83}
84
85impl MCPServer {
86    /// Create a new MCP server instance
87    pub fn new(
88        repository: Arc<MemoryRepository>,
89        embedder: Arc<SimpleEmbedder>,
90        config: MCPServerConfig,
91    ) -> Result<Self> {
92        // Initialize audit logger
93        let audit_logger = Arc::new(AuditLogger::new(config.audit.clone())?);
94
95        // Initialize authentication if enabled
96        let auth = if config.enable_authentication {
97            Some(Arc::new(MCPAuth::new(
98                config.auth.clone(),
99                audit_logger.clone(),
100            )?))
101        } else {
102            None
103        };
104
105        // Initialize rate limiting if enabled
106        let rate_limiter = if config.enable_rate_limiting {
107            Some(Arc::new(MCPRateLimiter::new(
108                config.rate_limiting.clone(),
109                audit_logger.clone(),
110            )))
111        } else {
112            None
113        };
114        // Initialize Silent Harvester Service
115        let importance_config = ImportanceAssessmentConfig::default();
116        let importance_pipeline = Arc::new(ImportanceAssessmentPipeline::new(
117            importance_config,
118            embedder.clone(),
119            prometheus::default_registry(),
120        )?);
121
122        let harvester_service = Arc::new(SilentHarvesterService::new(
123            repository.clone(),
124            importance_pipeline,
125            embedder.clone(),
126            None, // Use default config
127            prometheus::default_registry(),
128        )?);
129
130        // Initialize circuit breaker if enabled
131        let circuit_breaker = if config.enable_circuit_breaker {
132            Some(Arc::new(CircuitBreaker::new(
133                config.circuit_breaker.clone(),
134            )))
135        } else {
136            None
137        };
138
139        // Create handlers
140        let handlers = MCPHandlers::new(
141            repository.clone(),
142            embedder.clone(),
143            harvester_service.clone(),
144            circuit_breaker.clone(),
145            auth.clone(),
146            rate_limiter.clone(),
147        );
148
149        // Create transport
150        let transport = StdioTransport::new(config.request_timeout_ms);
151
152        Ok(Self {
153            config,
154            repository,
155            embedder,
156            handlers,
157            transport,
158            circuit_breaker,
159            harvester_service,
160            auth,
161            rate_limiter,
162            audit_logger,
163        })
164    }
165
166    /// Start the MCP server
167    pub async fn start(&mut self) -> Result<()> {
168        info!("Starting MCP server with stdio transport");
169        info!("Protocol version: 2025-06-18");
170        info!("Capabilities: tools");
171
172        // Start the transport layer
173        self.transport.start(&mut self.handlers).await
174    }
175
176    /// Get server statistics
177    pub async fn get_stats(&self) -> Result<serde_json::Value> {
178        let repo_stats = self.repository.get_statistics().await?;
179        let harvester_metrics = self.harvester_service.get_metrics().await;
180
181        let circuit_breaker_stats = if let Some(ref cb) = self.circuit_breaker {
182            Some(cb.get_stats().await)
183        } else {
184            None
185        };
186
187        let auth_stats = if let Some(ref auth) = self.auth {
188            Some(auth.get_stats().await)
189        } else {
190            None
191        };
192
193        let rate_limit_stats = if let Some(ref rl) = self.rate_limiter {
194            Some(rl.get_status().await)
195        } else {
196            None
197        };
198
199        Ok(serde_json::json!({
200            "server": {
201                "protocol_version": "2025-06-18",
202                "transport": "stdio",
203                "uptime_seconds": std::time::SystemTime::now()
204                    .duration_since(std::time::UNIX_EPOCH)
205                    .unwrap_or_default()
206                    .as_secs(),
207                "security": {
208                    "authentication_enabled": self.config.enable_authentication,
209                    "rate_limiting_enabled": self.config.enable_rate_limiting,
210                }
211            },
212            "memory_system": repo_stats,
213            "harvester": harvester_metrics,
214            "circuit_breaker": circuit_breaker_stats,
215            "authentication": auth_stats,
216            "rate_limiting": rate_limit_stats
217        }))
218    }
219
220    /// Shutdown the server gracefully
221    pub async fn shutdown(&mut self) -> Result<()> {
222        info!("Shutting down MCP server");
223
224        // Any cleanup logic here
225        if let Some(ref cb) = self.circuit_breaker {
226            cb.reset().await;
227        }
228
229        Ok(())
230    }
231}