// codex_memory/mcp/server.rs

1use crate::memory::{models::*, MemoryRepository};
2use crate::monitoring::{AlertManager, HealthChecker, MetricsCollector, PerformanceProfiler};
3use crate::SimpleEmbedder;
4use anyhow::Result;
5use jsonrpc_core::{IoHandler, Params, Value};
6use jsonrpc_tcp_server::ServerBuilder;
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tracing::info;
11use uuid::Uuid;
12
/// Connection and transport settings for the MCP server.
///
/// NOTE(review): nothing in this file reads these fields — `MCPServer::start`
/// receives its listen address as an argument — so they are presumably consumed
/// by whatever code builds and launches the server; confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpServerConfig {
    /// Socket address for the TCP transport.
    pub tcp_addr: SocketAddr,
    /// Optional path of a Unix domain socket; `None` leaves it unset.
    pub unix_socket_path: Option<String>,
    /// Connection limit — presumably the cap on simultaneous clients; confirm.
    pub max_connections: u32,
    /// Request timeout, in milliseconds.
    pub request_timeout_ms: u64,
    /// Request size limit — presumably in bytes (the default is written as
    /// `10 * 1024 * 1024`); confirm.
    pub max_request_size: usize,
    /// Whether compression is enabled.
    pub enable_compression: bool,
}
22
23impl Default for McpServerConfig {
24    fn default() -> Self {
25        Self {
26            tcp_addr: ([127, 0, 0, 1], 3333).into(),
27            unix_socket_path: None,
28            max_connections: 1000,
29            request_timeout_ms: 30000,
30            max_request_size: 10 * 1024 * 1024, // 10MB
31            enable_compression: true,
32        }
33    }
34}
35
36pub struct MCPServer {
37    repository: Arc<MemoryRepository>,
38    embedder: Arc<SimpleEmbedder>,
39    metrics: Arc<MetricsCollector>,
40    health_checker: Arc<HealthChecker>,
41    alert_manager: Arc<tokio::sync::RwLock<AlertManager>>,
42    profiler: Arc<PerformanceProfiler>,
43}
44
45impl MCPServer {
46    pub fn new(repository: Arc<MemoryRepository>, embedder: Arc<SimpleEmbedder>) -> Result<Self> {
47        let metrics = Arc::new(MetricsCollector::new()?);
48        let health_checker = Arc::new(HealthChecker::new(Arc::new(repository.pool().clone())));
49        let alert_manager = Arc::new(tokio::sync::RwLock::new(AlertManager::new()));
50        let profiler = Arc::new(PerformanceProfiler::new());
51
52        Ok(Self {
53            repository,
54            embedder,
55            metrics,
56            health_checker,
57            alert_manager,
58            profiler,
59        })
60    }
61
62    pub async fn start(&self, addr: SocketAddr) -> Result<()> {
63        let handler = self.create_handler().await;
64
65        // Start background monitoring task
66        self.start_monitoring_task().await;
67
68        let server = ServerBuilder::new(handler)
69            .start(&addr)
70            .map_err(|e| anyhow::anyhow!("Failed to start MCP server: {}", e))?;
71
72        info!("MCP server listening on {}", addr);
73
74        // Use tokio::spawn to handle the blocking wait call
75        let server_handle = tokio::task::spawn_blocking(move || {
76            server.wait();
77        });
78
79        // Wait for the server to finish
80        server_handle
81            .await
82            .map_err(|e| anyhow::anyhow!("Server task failed: {}", e))?;
83        Ok(())
84    }
85
86    async fn start_monitoring_task(&self) {
87        let health_checker = self.health_checker.clone();
88        let alert_manager = self.alert_manager.clone();
89        let metrics = self.metrics.clone();
90
91        tokio::spawn(async move {
92            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
93
94            loop {
95                interval.tick().await;
96
97                // Update derived metrics
98                metrics.update_derived_metrics();
99
100                // Run health checks and evaluate alerts
101                match health_checker.check_system_health().await {
102                    Ok(health) => {
103                        let mut alert_mgr = alert_manager.write().await;
104                        alert_mgr.evaluate_alerts(&health, None);
105
106                        // Cleanup old alerts (keep 24 hours of history)
107                        alert_mgr.cleanup_old_alerts(24);
108                    }
109                    Err(e) => {
110                        tracing::error!("Health check failed: {}", e);
111                    }
112                }
113            }
114        });
115
116        info!("Started background monitoring task");
117    }
118
119    async fn create_handler(&self) -> IoHandler {
120        let mut handler = IoHandler::new();
121
122        // Memory operations
123        let repo = self.repository.clone();
124        let embedder = self.embedder.clone();
125        let metrics = self.metrics.clone();
126        let profiler = self.profiler.clone();
127        handler.add_method("memory.create", move |params: Params| {
128            let repo = repo.clone();
129            let embedder = embedder.clone();
130            let metrics = metrics.clone();
131            let profiler = profiler.clone();
132            Box::pin(async move {
133                let _trace = profiler.start_trace("memory.create".to_string());
134                let start_time = std::time::Instant::now();
135
136                let mut request: CreateMemoryRequest = params.parse()?;
137
138                // Generate embedding if not provided
139                if request.embedding.is_none() {
140                    if let Ok(embedding) = embedder.generate_embedding(&request.content).await {
141                        request.embedding = Some(embedding);
142                    }
143                }
144
145                let result = repo.create_memory(request).await;
146
147                match &result {
148                    Ok(_) => {
149                        metrics.record_request(start_time);
150                        metrics.memory_creation_total.inc();
151                    }
152                    Err(_) => {
153                        metrics.record_db_query(start_time, false);
154                    }
155                }
156
157                let memory = result.map_err(|_| jsonrpc_core::Error::internal_error())?;
158                serde_json::to_value(memory).map_err(|_| jsonrpc_core::Error::internal_error())
159            })
160        });
161
162        let repo = self.repository.clone();
163        handler.add_method("memory.get", move |params: Params| {
164            let repo = repo.clone();
165            Box::pin(async move {
166                let (id,): (Uuid,) = params.parse()?;
167                let memory = repo
168                    .get_memory(id)
169                    .await
170                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
171                serde_json::to_value(memory).map_err(|_| jsonrpc_core::Error::internal_error())
172            })
173        });
174
175        let repo = self.repository.clone();
176        handler.add_method("memory.update", move |params: Params| {
177            let repo = repo.clone();
178            Box::pin(async move {
179                let (id, request): (Uuid, UpdateMemoryRequest) = params.parse()?;
180                let memory = repo
181                    .update_memory(id, request)
182                    .await
183                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
184                serde_json::to_value(memory).map_err(|_| jsonrpc_core::Error::internal_error())
185            })
186        });
187
188        let repo = self.repository.clone();
189        handler.add_method("memory.delete", move |params: Params| {
190            let repo = repo.clone();
191            Box::pin(async move {
192                let (id,): (Uuid,) = params.parse()?;
193                repo.delete_memory(id)
194                    .await
195                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
196                Ok(Value::Bool(true))
197            })
198        });
199
200        let repo = self.repository.clone();
201        let embedder = self.embedder.clone();
202        let metrics2 = self.metrics.clone();
203        let profiler2 = self.profiler.clone();
204        handler.add_method("memory.search", move |params: Params| {
205            let repo = repo.clone();
206            let embedder = embedder.clone();
207            let metrics = metrics2.clone();
208            let profiler = profiler2.clone();
209            Box::pin(async move {
210                let _trace = profiler.start_trace("memory.search".to_string());
211                let start_time = std::time::Instant::now();
212
213                let mut request: SearchRequest = params.parse()?;
214
215                // Generate query embedding if needed
216                if let Some(ref query_text) = request.query_text {
217                    if request.query_embedding.is_none() {
218                        if let Ok(embedding) = embedder.generate_embedding(query_text).await {
219                            request.query_embedding = Some(embedding);
220                        }
221                    }
222                }
223
224                let result = repo.search_memories(request).await;
225
226                match &result {
227                    Ok(response) => {
228                        metrics.record_search(start_time, response.results.len(), false);
229                    }
230                    Err(_) => {
231                        metrics.record_db_query(start_time, false);
232                    }
233                }
234
235                let results = result.map_err(|_| jsonrpc_core::Error::internal_error())?;
236                serde_json::to_value(results).map_err(|_| jsonrpc_core::Error::internal_error())
237            })
238        });
239
240        // Tier management
241        let repo = self.repository.clone();
242        handler.add_method("memory.migrate", move |params: Params| {
243            let repo = repo.clone();
244            Box::pin(async move {
245                let (id, tier, reason): (Uuid, MemoryTier, Option<String>) = params.parse()?;
246                let memory = repo
247                    .migrate_memory(id, tier, reason)
248                    .await
249                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
250                serde_json::to_value(memory).map_err(|_| jsonrpc_core::Error::internal_error())
251            })
252        });
253
254        // Statistics
255        let repo = self.repository.clone();
256        handler.add_method("memory.statistics", move |_params: Params| {
257            let repo = repo.clone();
258            Box::pin(async move {
259                let stats = repo
260                    .get_statistics()
261                    .await
262                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
263                serde_json::to_value(stats).map_err(|_| jsonrpc_core::Error::internal_error())
264            })
265        });
266
267        // Health check
268        let health_checker = self.health_checker.clone();
269        handler.add_method("health", move |_params| {
270            let health_checker = health_checker.clone();
271            Box::pin(async move {
272                match health_checker.check_system_health().await {
273                    Ok(health) => serde_json::to_value(health)
274                        .map_err(|_| jsonrpc_core::Error::internal_error()),
275                    Err(_) => Ok(Value::Object(serde_json::Map::from_iter([
276                        ("status".to_string(), Value::String("unhealthy".to_string())),
277                        (
278                            "timestamp".to_string(),
279                            Value::String(chrono::Utc::now().to_rfc3339()),
280                        ),
281                    ]))),
282                }
283            })
284        });
285
286        // Prometheus metrics endpoint
287        let metrics = self.metrics.clone();
288        handler.add_method("metrics", move |_params| {
289            let metrics = metrics.clone();
290            Box::pin(async move {
291                let metrics_text = metrics.gather_metrics();
292                Ok(Value::String(metrics_text))
293            })
294        });
295
296        // Performance summary
297        let profiler = self.profiler.clone();
298        handler.add_method("performance", move |_params| {
299            let profiler = profiler.clone();
300            Box::pin(async move {
301                let summary = profiler.get_performance_summary();
302                serde_json::to_value(summary).map_err(|_| jsonrpc_core::Error::internal_error())
303            })
304        });
305
306        // Active alerts
307        let alert_manager = self.alert_manager.clone();
308        handler.add_method("alerts", move |_params| {
309            let alert_manager = alert_manager.clone();
310            Box::pin(async move {
311                let binding = alert_manager.read().await;
312                let alerts = binding.get_active_alerts();
313                serde_json::to_value(alerts).map_err(|_| jsonrpc_core::Error::internal_error())
314            })
315        });
316
317        handler
318    }
319}
320
321#[derive(Debug, Clone, Serialize, Deserialize)]
322pub struct McpRequest<T> {
323    pub id: String,
324    pub method: String,
325    pub params: T,
326    pub correlation_id: String,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct McpResponse<T> {
331    pub id: String,
332    pub result: Option<T>,
333    pub error: Option<McpError>,
334    pub correlation_id: String,
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
338pub struct McpError {
339    pub code: i32,
340    pub message: String,
341    pub data: Option<Value>,
342}
343
#[cfg(test)]
mod tests {
    use super::McpServerConfig;

    /// The default configuration keeps its connection and timeout limits.
    #[test]
    fn test_default_config() {
        let McpServerConfig {
            max_connections,
            request_timeout_ms,
            ..
        } = McpServerConfig::default();

        assert_eq!(max_connections, 1000);
        assert_eq!(request_timeout_ms, 30000);
    }
}