1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::{Arc, Mutex, OnceLock};
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tracing::error;
6
7use crate::actions::memory;
8use crate::config::Config;
9use crate::errors::{MCSError, Result};
10use crate::kg::KnowledgeGraph;
11use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
12use crate::tools;
13
/// Capacity, in bytes, of the buffered reader wrapped around stdin (64 KiB).
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Line terminator appended after every serialized JSON-RPC response.
const NEWLINE: &[u8] = &[b'\n'];
16
/// Outcome of one call to `read_line_capped`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A line was read into the output buffer (newline included when one was present).
    Line,
    /// The stream ended before any byte of a new line was read.
    Eof,
    /// The line exceeded the byte limit and was not returned.
    TooLong,
}
22
23async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
24where
25 R: AsyncBufReadExt + Unpin,
26{
27 out.clear();
28 let mut buf: Vec<u8> = Vec::new();
29 loop {
30 let available = reader.fill_buf().await?;
31 if available.is_empty() {
32 if buf.is_empty() {
33 return Ok(LineRead::Eof);
34 }
35 *out = String::from_utf8(buf.clone()).map_err(|_| {
36 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
37 })?;
38 return Ok(LineRead::Line);
39 }
40 match available.iter().position(|&b| b == b'\n') {
41 Some(i) => {
42 if buf.len() + i + 1 > max {
43 reader.consume(i + 1);
44 return Ok(LineRead::TooLong);
45 }
46 buf.extend_from_slice(&available[..=i]);
47 reader.consume(i + 1);
48 *out = String::from_utf8(buf.clone()).map_err(|_| {
49 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
50 })?;
51 return Ok(LineRead::Line);
52 }
53 None => {
54 let take = available.len();
55 if buf.len() + take > max {
56 reader.consume(take);
57 return Ok(LineRead::TooLong);
58 }
59 buf.extend_from_slice(available);
60 reader.consume(take);
61 }
62 }
63 }
64}
65
66fn parse_error(msg: String) -> JsonRpcResponse {
67 let mcp_error = MCSError::ParseError(msg);
68 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
69}
70
71fn parse_request(line: &str) -> std::result::Result<JsonRpcRequest, String> {
72 let trimmed = line.trim();
73 if trimmed.is_empty() {
74 return Err("Empty request".to_string());
75 }
76 serde_json::from_str::<JsonRpcRequest>(trimmed).map_err(|e| e.to_string())
77}
78
79pub struct MCPServer {
80 _config: Arc<Config>,
81 kg: Arc<Mutex<KnowledgeGraph>>,
82}
83
84impl MCPServer {
85 pub fn new(config: Config) -> Result<Self> {
86 let path = Path::new(&config.memory_file_path);
87 let kg = KnowledgeGraph::new(path)
88 .map_err(MCSError::IoError)?;
89
90 Ok(Self {
91 _config: Arc::new(config),
92 kg: Arc::new(Mutex::new(kg)),
93 })
94 }
95
96 pub async fn run_stdio(&self) -> Result<()> {
97 let stdin = tokio::io::stdin();
98 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
99 let mut stdout = tokio::io::stdout();
100 let mut line = String::with_capacity(1024);
101 let mut response_buf = Vec::with_capacity(65536);
102 let max = 16 * 1024 * 1024;
103
104 loop {
105 match read_line_capped(&mut reader, &mut line, max).await {
106 Ok(LineRead::Eof) => break,
107 Ok(LineRead::Line) => {
108 process_one_line(&line, &self.kg, &mut response_buf, &mut stdout).await?;
109 }
110 Ok(LineRead::TooLong) => {
111 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
112 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
113 response_buf.clear();
114 serde_json::to_writer(&mut response_buf, &response).map_err(MCSError::JsonError)?;
115 response_buf.extend_from_slice(NEWLINE);
116 stdout.write_all(&response_buf).await.map_err(MCSError::IoError)?;
117 stdout.flush().await.map_err(MCSError::IoError)?;
118 break;
119 }
120 Err(e) => {
121 error!("IO error: {}", e);
122 break;
123 }
124 }
125 }
126 Ok(())
127 }
128}
129
130async fn process_one_line<W: AsyncWriteExt + Unpin>(
131 line: &str,
132 kg: &Mutex<KnowledgeGraph>,
133 response_buf: &mut Vec<u8>,
134 writer: &mut W,
135) -> Result<()> {
136 let (response, is_notification) = match parse_request(line) {
137 Ok(req) => {
138 let is_notif = req.id.is_none();
139 match process_request(&req, kg) {
140 Ok(result) => (JsonRpcResponse::success(req.id, result), is_notif),
141 Err(e) => (JsonRpcResponse::error(req.id, e.error_code(), e.to_string()), is_notif),
142 }
143 }
144 Err(e) => (parse_error(e), false),
145 };
146
147 if is_notification {
148 return Ok(());
149 }
150
151 response_buf.clear();
152 serde_json::to_writer(&mut *response_buf, &response).map_err(MCSError::JsonError)?;
153 response_buf.extend_from_slice(NEWLINE);
154
155 writer.write_all(response_buf).await.map_err(MCSError::IoError)?;
156 writer.flush().await.map_err(MCSError::IoError)?;
157 Ok(())
158}
159
160fn process_request(req: &JsonRpcRequest, kg: &Mutex<KnowledgeGraph>) -> Result<Value> {
161 match req.method.as_str() {
162 "initialize" => handle_initialize(),
163 "tools/list" => handle_tools_list(),
164 "tools/call" => handle_tools_call(req, kg),
165 "ping" => handle_ping(),
166 method if method.starts_with("notifications/") => handle_notification(method),
167 _ => Err(MCSError::MethodNotFound(req.method.clone())),
168 }
169}
170
171const fn handle_ping() -> Result<Value> {
172 Ok(Value::Null)
173}
174
175fn handle_notification(method: &str) -> Result<Value> {
176 tracing::trace!("Received notification: {method}");
177 Ok(Value::Null)
178}
179
180fn handle_initialize() -> Result<Value> {
181 Ok(json!({
182 "protocolVersion": "2024-11-05",
183 "capabilities": {
184 "tools": { "listChanged": false }
185 },
186 "serverInfo": {
187 "name": "mcp-memory",
188 "version": env!("CARGO_PKG_VERSION")
189 }
190 }))
191}
192
193fn handle_tools_list() -> Result<Value> {
194 static CACHED: OnceLock<Value> = OnceLock::new();
195 if let Some(cached) = CACHED.get() {
196 return Ok(cached.clone());
197 }
198 let tools_json = include_str!("../tools.json");
199 let tools: Vec<Value> =
200 serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
201 let result = json!({ "tools": tools });
202 let _ = CACHED.set(result.clone());
203 Ok(result)
204}
205
206fn handle_tools_call(req: &JsonRpcRequest, kg: &Mutex<KnowledgeGraph>) -> Result<Value> {
207 let tool_name = req
208 .params
209 .as_ref()
210 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
211 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
212
213 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
214
215 if !tools::tool_exists(tool_name) {
216 return Err(MCSError::MethodNotFound(tool_name.to_string()));
217 }
218
219 match tool_name {
220 "create_entities" => memory::handle_create_entities(kg, tool_args),
221 "create_relations" => memory::handle_create_relations(kg, tool_args),
222 "add_observations" => memory::handle_add_observations(kg, tool_args),
223 "delete_entities" => memory::handle_delete_entities(kg, tool_args),
224 "delete_observations" => memory::handle_delete_observations(kg, tool_args),
225 "delete_relations" => memory::handle_delete_relations(kg, tool_args),
226 "read_graph" => memory::handle_read_graph(kg),
227 "search_nodes" => memory::handle_search_nodes(kg, tool_args),
228 "open_nodes" => memory::handle_open_nodes(kg, tool_args),
229 "get_entity" => memory::handle_get_entity(kg, tool_args),
230 "graph_stats" => memory::handle_graph_stats(kg),
231 "search_relations" => memory::handle_search_relations(kg, tool_args),
232 "find_path" => memory::handle_find_path(kg, tool_args),
233 "compact" => memory::handle_compact(kg),
234 tool => Err(MCSError::MethodNotFound(tool.to_string())),
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Combined with the process id, gives every test its own backing-file name.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Deletes the wrapped file on drop, so temporary graph files are removed even when a
    /// test panics before reaching its end.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Builds a `KnowledgeGraph` backed by a unique file in the platform temp directory.
    ///
    /// Bind the returned guard to a named variable so it lives until the end of the test.
    fn setup_kg() -> (Arc<Mutex<KnowledgeGraph>>, TempFile) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = KnowledgeGraph::new(path.as_path()).unwrap();
        (Arc::new(Mutex::new(kg)), TempFile(path))
    }

    #[test]
    fn test_parse_valid_request() {
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let req = parse_request(line).unwrap();
        assert_eq!(req.method, "initialize");
    }

    #[test]
    fn test_parse_invalid_json() {
        let err = parse_request("{invalid}").unwrap_err();
        assert!(!err.is_empty());
    }

    #[test]
    fn test_parse_empty_request() {
        assert_eq!(parse_request("").unwrap_err(), "Empty request");
        assert_eq!(parse_request("  \r\n").unwrap_err(), "Empty request");
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, _tmp) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
    }

    #[test]
    fn test_unknown_method_is_rejected() {
        let (kg, _tmp) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "no/such/method".to_string(),
            params: None,
            id: Some(Value::Number(2.into())),
        };
        assert!(process_request(&req, &kg).is_err());
    }
}