1pub mod handlers;
3pub mod tools;
4pub mod transport;
5
6pub use handlers::MCPHandlers;
8
9use crate::config::Config;
10use crate::error::Result;
11use crate::storage::Storage;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::io::AsyncWriteExt;
15use tokio_util::codec::{Decoder, FramedRead};
16use futures_util::StreamExt;
17use tracing::{error, info, warn};
18
19pub struct MCPServer {
21 _config: Config,
22 handlers: Arc<MCPHandlers>,
23 start_time: Instant,
24 last_request: Arc<std::sync::Mutex<Instant>>,
25}
26
27impl MCPServer {
28 pub fn new(config: Config, storage: Arc<Storage>) -> Self {
30 let handlers = Arc::new(MCPHandlers::new(storage));
31 let now = Instant::now();
32 Self {
33 _config: config,
34 handlers,
35 start_time: now,
36 last_request: Arc::new(std::sync::Mutex::new(now)),
37 }
38 }
39
40 fn should_terminate(&self) -> bool {
42 let last_request = *self.last_request.lock().unwrap();
43 let inactive_duration = last_request.elapsed();
44
45 if inactive_duration > Duration::from_secs(86400) {
47 warn!(
48 "Server inactive for {:?}, initiating shutdown",
49 inactive_duration
50 );
51 return true;
52 }
53
54 false
55 }
56
57 fn update_last_request(&self) {
59 *self.last_request.lock().unwrap() = Instant::now();
60 }
61
62 async fn health_monitor(&self) {
64 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
67 interval.tick().await;
68
69 if self.should_terminate() {
70 error!("Health monitor detected inactivity timeout, terminating process");
71 std::process::exit(1);
72 }
73
74 let uptime = self.start_time.elapsed();
75 let last_request_ago = self.last_request.lock().unwrap().elapsed();
76
77 info!(
78 "Health check: uptime={:?}, last_request={:?} ago",
79 uptime, last_request_ago
80 );
81 }
82 }
83
84 pub async fn run_stdio(&self) -> Result<()> {
86 info!("MCP server running in stdio mode with secure JSON streaming");
87
88 let health_monitor = {
90 let server_clone = Self {
91 _config: self._config.clone(),
92 handlers: Arc::clone(&self.handlers),
93 start_time: self.start_time,
94 last_request: Arc::clone(&self.last_request),
95 };
96 tokio::spawn(async move {
97 server_clone.health_monitor().await;
98 })
99 };
100
101 let stdin = tokio::io::stdin();
102 let stdout = tokio::io::stdout();
103 let mut stdout = stdout;
104
105 let mut framed = FramedRead::new(stdin, SecureJsonDecoder::new());
107
108 loop {
109 tokio::select! {
110 message = framed.next() => {
112 match message {
113 Some(Ok(json_str)) => {
114 info!("Processing JSON request ({} chars)", json_str.len());
115 self.update_last_request();
116 let response = self.handle_request(&json_str).await;
117 if !response.is_empty() {
118 stdout.write_all(response.as_bytes()).await?;
119 stdout.write_all(b"\n").await?;
120 stdout.flush().await?;
121 }
122 }
123 Some(Err(e)) => {
124 let error_msg = e.to_string();
126 if !error_msg.contains("bytes remaining on stream") {
127 error!("JSON decode error: {}", e);
128 let parse_error = crate::error::Error::ParseError(e.to_string());
130 let error_response = parse_error.to_json_rpc_error(None);
131 stdout.write_all(serde_json::to_string(&error_response).unwrap().as_bytes()).await?;
132 stdout.flush().await?;
133 }
134 }
135 None => {
136 info!("Received EOF, shutting down MCP server");
137 break;
138 }
139 }
140 }
141 _ = tokio::time::sleep(Duration::from_secs(60)) => {
143 if self.should_terminate() {
144 warn!("MCP server inactive for too long, initiating graceful shutdown");
145 break;
146 }
147 }
148 }
149 }
150
151 info!("MCP server shutting down gracefully");
152
153 health_monitor.abort();
155
156 Ok(())
157 }
158
159 pub async fn handle_request(&self, request: &str) -> String {
163 info!("Raw request to parse: {:?}", request);
165
166 let request: serde_json::Value = match serde_json::from_str(request) {
167 Ok(v) => v,
168 Err(e) => {
169 error!("JSON parse error: {} - Request: {:?}", e, request);
170 let parse_error = crate::error::Error::ParseError(e.to_string());
171 return serde_json::to_string(&parse_error.to_json_rpc_error(Some(serde_json::json!(0))))
172 .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32700,"message":"Parse error"}}"#.to_string());
173 }
174 };
175
176 let method = request["method"].as_str().unwrap_or("");
178 if method.is_empty() {
179 let invalid_request_error = crate::error::Error::InvalidRequest("Missing 'method' field".to_string());
180 return serde_json::to_string(&invalid_request_error.to_json_rpc_error(request.get("id").cloned()))
181 .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid Request"}}"#.to_string());
182 }
183
184 let params = request.get("params").cloned().unwrap_or_default();
185 let id = request.get("id").cloned();
186
187 let result = match method {
188 "initialize" => Ok(serde_json::json!({
189 "protocolVersion": "2024-11-05",
190 "capabilities": {
191 "tools": {}
192 },
193 "serverInfo": {
194 "name": "codex-memory",
195 "version": env!("CARGO_PKG_VERSION")
196 }
197 })),
198 "tools/list" => Ok(serde_json::json!({
199 "tools": tools::MCPTools::get_tools_list()
200 })),
201 "tools/call" => {
202 let tool_name = params["name"].as_str().unwrap_or("");
203 let tool_params = params.get("arguments").cloned().unwrap_or_default();
204
205 let timeout_duration = std::time::Duration::from_secs(60);
207
208 match tokio::time::timeout(timeout_duration,
209 self.handlers.handle_tool_call(tool_name, tool_params)
210 ).await {
211 Ok(result) => result,
212 Err(_) => Err(crate::error::Error::Timeout(format!(
213 "Tool call '{}' timed out after {} seconds",
214 tool_name,
215 timeout_duration.as_secs()
216 )))
217 }
218 }
219 "prompts/list" => {
220 Ok(serde_json::json!({
222 "prompts": []
223 }))
224 }
225 "resources/list" => {
226 Ok(serde_json::json!({
228 "resources": []
229 }))
230 }
231 "notifications/initialized" => {
232 return "".to_string(); }
235 _ => {
236 Err(crate::error::Error::MethodNotFound(format!(
238 "Unknown method: {}. Supported methods: initialize, tools/list, tools/call, prompts/list, resources/list, notifications/initialized",
239 method
240 )))
241 }
242 };
243
244 match result {
245 Ok(value) => {
246 if let Some(id) = id {
247 format!(r#"{{"jsonrpc":"2.0","id":{},"result":{}}}"#, id, value)
248 } else {
249 format!(r#"{{"jsonrpc":"2.0","result":{}}}"#, value)
250 }
251 }
252 Err(e) => {
253 error!("MCP request failed - Method: {}, Error: {}", method, e);
255
256 let error_response = e.to_json_rpc_error(id.or_else(|| Some(serde_json::json!(0))));
258 serde_json::to_string(&error_response)
259 .unwrap_or_else(|_| r#"{"jsonrpc":"2.0","id":0,"error":{"code":-32603,"message":"Internal error"}}"#.to_string())
260 }
261 }
262 }
263}
264
/// Stream decoder that frames raw stdin bytes into individual JSON objects.
///
/// It enforces a limit on buffered input.
struct SecureJsonDecoder {
    /// Largest number of buffered, not-yet-framed bytes accepted before decoding fails.
    max_buffer_size: usize,
}
270
271impl SecureJsonDecoder {
272 fn new() -> Self {
273 Self {
274 max_buffer_size: 10 * 1024 * 1024, }
276 }
277}
278
279impl Decoder for SecureJsonDecoder {
280 type Item = String;
281 type Error = std::io::Error;
282
283 fn decode(&mut self, src: &mut bytes::BytesMut) -> std::result::Result<Option<Self::Item>, Self::Error> {
284 if src.len() > self.max_buffer_size {
286 return Err(std::io::Error::new(
287 std::io::ErrorKind::InvalidData,
288 format!(
289 "Buffer size limit exceeded: {} bytes (max: {})",
290 src.len(),
291 self.max_buffer_size
292 ),
293 ));
294 }
295
296 match std::str::from_utf8(src) {
298 Ok(_) => {}, Err(_) => {
300 return Err(std::io::Error::new(
301 std::io::ErrorKind::InvalidData,
302 "Invalid UTF-8 encoding in JSON stream",
303 ));
304 }
305 };
306
307 let mut depth = 0;
309 let mut in_string = false;
310 let mut escape_next = false;
311 let mut json_start = None;
312
313 for (i, byte) in src.iter().enumerate() {
314 let ch = *byte as char;
315
316 if escape_next {
317 escape_next = false;
318 continue;
319 }
320
321 match ch {
322 '\\' if in_string => escape_next = true,
323 '"' => in_string = !in_string,
324 '{' if !in_string => {
325 if json_start.is_none() {
326 json_start = Some(i);
327 }
328 depth += 1;
329 if depth > 100 {
331 return Err(std::io::Error::new(
332 std::io::ErrorKind::InvalidData,
333 "JSON nesting depth exceeded (max: 100 levels)",
334 ));
335 }
336 }
337 '}' if !in_string => {
338 depth -= 1;
339 if depth == 0 && json_start.is_some() {
340 let json_bytes = src.split_to(i + 1);
342
343 let json_str = match std::str::from_utf8(&json_bytes) {
345 Ok(s) => s.to_string(),
346 Err(e) => {
347 return Err(std::io::Error::new(
348 std::io::ErrorKind::InvalidData,
349 format!("Invalid UTF-8 in JSON: {}", e),
350 ));
351 }
352 };
353
354 match serde_json::from_str::<serde_json::Value>(&json_str) {
356 Ok(_) => return Ok(Some(json_str)),
357 Err(e) => {
358 return Err(std::io::Error::new(
359 std::io::ErrorKind::InvalidData,
360 format!("Invalid JSON structure: {}", e),
361 ));
362 }
363 }
364 }
365 }
366 _ => {}
367 }
368 }
369
370 Ok(None)
372 }
373}