codex_memory/mcp/
server.rs

1use crate::memory::{models::*, MemoryRepository};
2use crate::monitoring::{AlertManager, HealthChecker, MetricsCollector, PerformanceProfiler};
3use crate::SimpleEmbedder;
4use anyhow::Result;
5use jsonrpc_core::{IoHandler, Params, Value};
6use jsonrpc_tcp_server::ServerBuilder;
7use serde::{Deserialize, Serialize};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tracing::info;
11use uuid::Uuid;
12
/// Tunable settings for the MCP server.
///
/// `Debug`, `PartialEq` and `Eq` are derived in addition to `Clone` so the
/// configuration can be logged and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpServerConfig {
    /// TCP socket address for the server.
    pub tcp_addr: SocketAddr,
    /// Optional Unix domain socket path; `None` when no path is configured.
    pub unix_socket_path: Option<String>,
    /// Connection limit (not enforced anywhere in this file).
    pub max_connections: u32,
    /// Per-request timeout, in milliseconds.
    pub request_timeout_ms: u64,
    /// Request size limit; presumably bytes, going by the `10MB` default.
    pub max_request_size: usize,
    /// Whether compression is enabled.
    pub enable_compression: bool,
}
22
23impl Default for McpServerConfig {
24    fn default() -> Self {
25        Self {
26            tcp_addr: ([127, 0, 0, 1], 3333).into(),
27            unix_socket_path: None,
28            max_connections: 1000,
29            request_timeout_ms: 30000,
30            max_request_size: 10 * 1024 * 1024, // 10MB
31            enable_compression: true,
32        }
33    }
34}
35
36pub struct MCPServer {
37    repository: Arc<MemoryRepository>,
38    embedder: Arc<SimpleEmbedder>,
39    metrics: Arc<MetricsCollector>,
40    health_checker: Arc<HealthChecker>,
41    alert_manager: Arc<tokio::sync::RwLock<AlertManager>>,
42    profiler: Arc<PerformanceProfiler>,
43}
44
45impl MCPServer {
46    pub fn new(repository: Arc<MemoryRepository>, embedder: Arc<SimpleEmbedder>) -> Result<Self> {
47        let metrics = Arc::new(MetricsCollector::new()?);
48        let health_checker = Arc::new(HealthChecker::new(Arc::new(repository.pool().clone())));
49        let alert_manager = Arc::new(tokio::sync::RwLock::new(AlertManager::new()));
50        let profiler = Arc::new(PerformanceProfiler::new());
51
52        Ok(Self {
53            repository,
54            embedder,
55            metrics,
56            health_checker,
57            alert_manager,
58            profiler,
59        })
60    }
61
62    pub async fn start(&self, addr: SocketAddr) -> Result<()> {
63        let handler = self.create_handler().await;
64
65        // Start background monitoring task
66        self.start_monitoring_task().await;
67
68        let server = ServerBuilder::new(handler)
69            .start(&addr)
70            .map_err(|e| anyhow::anyhow!("Failed to start MCP server: {}", e))?;
71
72        info!("MCP server listening on {}", addr);
73
74        // Use tokio::spawn to handle the blocking wait call
75        let server_handle = tokio::task::spawn_blocking(move || {
76            server.wait();
77        });
78
79        // Wait for the server to finish
80        server_handle
81            .await
82            .map_err(|e| anyhow::anyhow!("Server task failed: {}", e))?;
83        Ok(())
84    }
85
86    async fn start_monitoring_task(&self) {
87        let health_checker = self.health_checker.clone();
88        let alert_manager = self.alert_manager.clone();
89        let metrics = self.metrics.clone();
90
91        tokio::spawn(async move {
92            let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
93
94            loop {
95                interval.tick().await;
96
97                // Update derived metrics
98                metrics.update_derived_metrics();
99
100                // Run health checks and evaluate alerts
101                match health_checker.check_system_health().await {
102                    Ok(health) => {
103                        let mut alert_mgr = alert_manager.write().await;
104                        alert_mgr.evaluate_alerts(&health, None);
105
106                        // Cleanup old alerts (keep 24 hours of history)
107                        alert_mgr.cleanup_old_alerts(24);
108                    }
109                    Err(e) => {
110                        tracing::error!("Health check failed: {}", e);
111                    }
112                }
113            }
114        });
115
116        info!("Started background monitoring task");
117    }
118
119    async fn create_handler(&self) -> IoHandler {
120        let mut handler = IoHandler::new();
121
122        // Memory operations
123        let repo = self.repository.clone();
124        let embedder = self.embedder.clone();
125        let metrics = self.metrics.clone();
126        let profiler = self.profiler.clone();
127        handler.add_method("memory.create", move |params: Params| {
128            let repo = repo.clone();
129            let embedder = embedder.clone();
130            let metrics = metrics.clone();
131            let profiler = profiler.clone();
132            Box::pin(async move {
133                let _trace = profiler.start_trace("memory.create".to_string());
134                let start_time = std::time::Instant::now();
135
136                let mut request: CreateMemoryRequest = params.parse()?;
137
138                // Generate embedding if not provided
139                if request.embedding.is_none() {
140                    if let Ok(embedding) = embedder.generate_embedding(&request.content).await {
141                        request.embedding = Some(embedding);
142                    }
143                }
144
145                let result = repo.create_memory(request).await;
146
147                match &result {
148                    Ok(_) => {
149                        metrics.record_request(start_time);
150                        metrics.memory_creation_total.inc();
151                    }
152                    Err(_) => {
153                        metrics.record_db_query(start_time, false);
154                    }
155                }
156
157                let memory = result.map_err(|_| jsonrpc_core::Error::internal_error())?;
158                Ok(serde_json::to_value(memory).unwrap())
159            })
160        });
161
162        let repo = self.repository.clone();
163        handler.add_method("memory.get", move |params: Params| {
164            let repo = repo.clone();
165            Box::pin(async move {
166                let (id,): (Uuid,) = params.parse()?;
167                let memory = repo
168                    .get_memory(id)
169                    .await
170                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
171                Ok(serde_json::to_value(memory).unwrap())
172            })
173        });
174
175        let repo = self.repository.clone();
176        handler.add_method("memory.update", move |params: Params| {
177            let repo = repo.clone();
178            Box::pin(async move {
179                let (id, request): (Uuid, UpdateMemoryRequest) = params.parse()?;
180                let memory = repo
181                    .update_memory(id, request)
182                    .await
183                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
184                Ok(serde_json::to_value(memory).unwrap())
185            })
186        });
187
188        let repo = self.repository.clone();
189        handler.add_method("memory.delete", move |params: Params| {
190            let repo = repo.clone();
191            Box::pin(async move {
192                let (id,): (Uuid,) = params.parse()?;
193                repo.delete_memory(id)
194                    .await
195                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
196                Ok(Value::Bool(true))
197            })
198        });
199
200        let repo = self.repository.clone();
201        let embedder = self.embedder.clone();
202        let metrics2 = self.metrics.clone();
203        let profiler2 = self.profiler.clone();
204        handler.add_method("memory.search", move |params: Params| {
205            let repo = repo.clone();
206            let embedder = embedder.clone();
207            let metrics = metrics2.clone();
208            let profiler = profiler2.clone();
209            Box::pin(async move {
210                let _trace = profiler.start_trace("memory.search".to_string());
211                let start_time = std::time::Instant::now();
212
213                let mut request: SearchRequest = params.parse()?;
214
215                // Generate query embedding if needed
216                if let Some(ref query_text) = request.query_text {
217                    if request.query_embedding.is_none() {
218                        if let Ok(embedding) = embedder.generate_embedding(query_text).await {
219                            request.query_embedding = Some(embedding);
220                        }
221                    }
222                }
223
224                let result = repo.search_memories(request).await;
225
226                match &result {
227                    Ok(response) => {
228                        metrics.record_search(start_time, response.results.len(), false);
229                    }
230                    Err(_) => {
231                        metrics.record_db_query(start_time, false);
232                    }
233                }
234
235                let results = result.map_err(|_| jsonrpc_core::Error::internal_error())?;
236                Ok(serde_json::to_value(results).unwrap())
237            })
238        });
239
240        // Tier management
241        let repo = self.repository.clone();
242        handler.add_method("memory.migrate", move |params: Params| {
243            let repo = repo.clone();
244            Box::pin(async move {
245                let (id, tier, reason): (Uuid, MemoryTier, Option<String>) = params.parse()?;
246                let memory = repo
247                    .migrate_memory(id, tier, reason)
248                    .await
249                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
250                Ok(serde_json::to_value(memory).unwrap())
251            })
252        });
253
254        // Statistics
255        let repo = self.repository.clone();
256        handler.add_method("memory.statistics", move |_params: Params| {
257            let repo = repo.clone();
258            Box::pin(async move {
259                let stats = repo
260                    .get_statistics()
261                    .await
262                    .map_err(|_| jsonrpc_core::Error::internal_error())?;
263                Ok(serde_json::to_value(stats).unwrap())
264            })
265        });
266
267        // Health check
268        let health_checker = self.health_checker.clone();
269        handler.add_method("health", move |_params| {
270            let health_checker = health_checker.clone();
271            Box::pin(async move {
272                match health_checker.check_system_health().await {
273                    Ok(health) => Ok(serde_json::to_value(health).unwrap()),
274                    Err(_) => Ok(Value::Object(serde_json::Map::from_iter([
275                        ("status".to_string(), Value::String("unhealthy".to_string())),
276                        (
277                            "timestamp".to_string(),
278                            Value::String(chrono::Utc::now().to_rfc3339()),
279                        ),
280                    ]))),
281                }
282            })
283        });
284
285        // Prometheus metrics endpoint
286        let metrics = self.metrics.clone();
287        handler.add_method("metrics", move |_params| {
288            let metrics = metrics.clone();
289            Box::pin(async move {
290                let metrics_text = metrics.gather_metrics();
291                Ok(Value::String(metrics_text))
292            })
293        });
294
295        // Performance summary
296        let profiler = self.profiler.clone();
297        handler.add_method("performance", move |_params| {
298            let profiler = profiler.clone();
299            Box::pin(async move {
300                let summary = profiler.get_performance_summary();
301                Ok(serde_json::to_value(summary).unwrap())
302            })
303        });
304
305        // Active alerts
306        let alert_manager = self.alert_manager.clone();
307        handler.add_method("alerts", move |_params| {
308            let alert_manager = alert_manager.clone();
309            Box::pin(async move {
310                let binding = alert_manager.read().await;
311                let alerts = binding.get_active_alerts();
312                Ok(serde_json::to_value(alerts).unwrap())
313            })
314        });
315
316        handler
317    }
318}
319
320#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct McpRequest<T> {
322    pub id: String,
323    pub method: String,
324    pub params: T,
325    pub correlation_id: String,
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize)]
329pub struct McpResponse<T> {
330    pub id: String,
331    pub result: Option<T>,
332    pub error: Option<McpError>,
333    pub correlation_id: String,
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct McpError {
338    pub code: i32,
339    pub message: String,
340    pub data: Option<Value>,
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the connection-limit and request-timeout defaults.
    #[test]
    fn test_default_config() {
        let McpServerConfig {
            max_connections,
            request_timeout_ms,
            ..
        } = McpServerConfig::default();

        assert_eq!(max_connections, 1000);
        assert_eq!(request_timeout_ms, 30000);
    }
}