1use std::sync::Arc;
32
33use rmcp::handler::server::ServerHandler;
34use rmcp::model::{
35 CallToolRequestParam, CallToolResult, Content, Implementation, ListToolsResult,
36 PaginatedRequestParam, ProtocolVersion, ServerCapabilities, ServerInfo, Tool,
37 ToolsCapability,
38};
39use rmcp::service::{RequestContext, RoleServer};
40use rmcp::{Error as McpError, ServiceExt};
41use serde::{Deserialize, Serialize};
42use solo_core::{
43 Confidence, Embedder, EncodingContext, Episode, MemoryId, Tier,
44 VectorIndex,
45};
46use solo_storage::{ReaderPool, WriteHandle};
47use std::str::FromStr;
48
49#[derive(Clone)]
51pub struct SoloMcpServer {
52 inner: Arc<Inner>,
53}
54
55struct Inner {
56 write: WriteHandle,
57 pool: ReaderPool,
58 embedder: Arc<dyn Embedder>,
59 hnsw: Arc<dyn VectorIndex + Send + Sync>,
60}
61
62impl SoloMcpServer {
63 pub fn new(
64 write: WriteHandle,
65 pool: ReaderPool,
66 embedder: Arc<dyn Embedder>,
67 hnsw: Arc<dyn VectorIndex + Send + Sync>,
68 ) -> Self {
69 Self {
70 inner: Arc::new(Inner {
71 write,
72 pool,
73 embedder,
74 hnsw,
75 }),
76 }
77 }
78}
79
80pub async fn serve_stdio(server: SoloMcpServer) -> anyhow::Result<()> {
83 use rmcp::transport::io::stdio;
84 let (stdin, stdout) = stdio();
85 let running = server.serve((stdin, stdout)).await?;
86 running.waiting().await?;
87 Ok(())
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct RememberArgs {
96 pub content: String,
97 #[serde(default)]
98 pub source_type: Option<String>,
99 #[serde(default)]
100 pub source_id: Option<String>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct RecallArgs {
105 pub query: String,
106 #[serde(default = "default_limit")]
107 pub limit: usize,
108}
109
110fn default_limit() -> usize {
111 5
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ForgetArgs {
116 pub memory_id: String,
117 #[serde(default = "default_forget_reason")]
118 pub reason: String,
119}
120
121fn default_forget_reason() -> String {
122 "user-initiated via MCP".into()
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct InspectArgs {
127 pub memory_id: String,
128}
129
130impl ServerHandler for SoloMcpServer {
135 fn get_info(&self) -> ServerInfo {
136 ServerInfo {
137 protocol_version: ProtocolVersion::default(),
138 capabilities: ServerCapabilities {
139 tools: Some(ToolsCapability {
140 list_changed: Some(false),
141 }),
142 ..Default::default()
143 },
144 server_info: Implementation {
145 name: "solo".into(),
146 version: env!("CARGO_PKG_VERSION").into(),
147 },
148 instructions: Some(
149 "Solo: local-first personal memory for LLMs. Use \
150 memory.remember to store, memory.recall to search, \
151 memory.forget to soft-delete, and memory.inspect to \
152 fetch a full record."
153 .into(),
154 ),
155 }
156 }
157
158 async fn list_tools(
159 &self,
160 _request: PaginatedRequestParam,
161 _context: RequestContext<RoleServer>,
162 ) -> std::result::Result<ListToolsResult, McpError> {
163 Ok(ListToolsResult {
164 tools: build_tools(),
165 next_cursor: None,
166 })
167 }
168
169 async fn call_tool(
170 &self,
171 request: CallToolRequestParam,
172 _context: RequestContext<RoleServer>,
173 ) -> std::result::Result<CallToolResult, McpError> {
174 let CallToolRequestParam { name, arguments } = request;
175 let args_value = serde_json::Value::Object(arguments.unwrap_or_default());
176 self.dispatch_tool(&name, args_value).await
177 }
178}
179
180impl SoloMcpServer {
181 pub async fn dispatch_tool(
187 &self,
188 name: &str,
189 args_value: serde_json::Value,
190 ) -> std::result::Result<CallToolResult, McpError> {
191 match name {
192 "memory.remember" => {
193 let args: RememberArgs = parse_args(&args_value)?;
194 self.handle_remember(args).await
195 }
196 "memory.recall" => {
197 let args: RecallArgs = parse_args(&args_value)?;
198 self.handle_recall(args).await
199 }
200 "memory.forget" => {
201 let args: ForgetArgs = parse_args(&args_value)?;
202 self.handle_forget(args).await
203 }
204 "memory.inspect" => {
205 let args: InspectArgs = parse_args(&args_value)?;
206 self.handle_inspect(args).await
207 }
208 other => Err(McpError::invalid_params(
209 format!("unknown tool `{other}`"),
210 None,
211 )),
212 }
213 }
214
215 pub fn dispatch_list_tools(&self) -> Vec<Tool> {
218 build_tools()
219 }
220}
221
222fn parse_args<T: serde::de::DeserializeOwned>(
223 v: &serde_json::Value,
224) -> std::result::Result<T, McpError> {
225 serde_json::from_value(v.clone()).map_err(|e| {
226 McpError::invalid_params(format!("invalid tool arguments: {e}"), None)
227 })
228}
229
230fn solo_to_mcp(e: solo_core::Error) -> McpError {
231 use solo_core::Error;
232 match e {
233 Error::NotFound(msg) => McpError::invalid_params(msg, None),
234 Error::InvalidInput(msg) => McpError::invalid_params(msg, None),
235 Error::Conflict(msg) => McpError::invalid_params(msg, None),
236 other => McpError::internal_error(other.to_string(), None),
237 }
238}
239
240fn build_tools() -> Vec<Tool> {
245 vec![
246 Tool::new(
247 "memory.remember",
248 "Store a new episodic memory. Returns the new MemoryId (UUID v7).",
249 json_schema_object(serde_json::json!({
250 "type": "object",
251 "properties": {
252 "content": {
253 "type": "string",
254 "description": "The text to remember.",
255 },
256 "source_type": {
257 "type": "string",
258 "description": "Optional source-type tag (default: \"user_message\").",
259 },
260 "source_id": {
261 "type": "string",
262 "description": "Optional upstream id for traceability.",
263 },
264 },
265 "required": ["content"],
266 })),
267 ),
268 Tool::new(
269 "memory.recall",
270 "Vector-search the memory store. Returns up to `limit` results \
271 ordered by cosine distance (smaller = more similar). Excludes \
272 forgotten memories.",
273 json_schema_object(serde_json::json!({
274 "type": "object",
275 "properties": {
276 "query": {
277 "type": "string",
278 "description": "The query text.",
279 },
280 "limit": {
281 "type": "integer",
282 "description": "Maximum results (default 5).",
283 "minimum": 1,
284 "maximum": 100,
285 },
286 },
287 "required": ["query"],
288 })),
289 ),
290 Tool::new(
291 "memory.forget",
292 "Soft-delete a memory by id. The HNSW vector stays in the graph \
293 but the SQL row's status flips to 'forgotten' so future recalls \
294 exclude it.",
295 json_schema_object(serde_json::json!({
296 "type": "object",
297 "properties": {
298 "memory_id": {
299 "type": "string",
300 "description": "MemoryId to forget (UUID v7).",
301 },
302 "reason": {
303 "type": "string",
304 "description": "Optional free-form reason (logged, not yet persisted).",
305 },
306 },
307 "required": ["memory_id"],
308 })),
309 ),
310 Tool::new(
311 "memory.inspect",
312 "Return the full record for a memory_id (timestamps, source, \
313 status, scoring values, content).",
314 json_schema_object(serde_json::json!({
315 "type": "object",
316 "properties": {
317 "memory_id": {
318 "type": "string",
319 "description": "MemoryId to inspect (UUID v7).",
320 },
321 },
322 "required": ["memory_id"],
323 })),
324 ),
325 ]
326}
327
328fn json_schema_object(value: serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
329 match value {
330 serde_json::Value::Object(map) => map,
331 _ => panic!("json_schema_object: input must be an object"),
332 }
333}
334
335impl SoloMcpServer {
340 async fn handle_remember(
341 &self,
342 args: RememberArgs,
343 ) -> std::result::Result<CallToolResult, McpError> {
344 let content = args.content.trim_end().to_string();
345 if content.is_empty() {
346 return Err(McpError::invalid_params(
347 "memory.remember: content must not be empty".to_string(),
348 None,
349 ));
350 }
351 let embedding: solo_core::Embedding = self
352 .inner
353 .embedder
354 .embed(&content)
355 .await
356 .map_err(solo_to_mcp)?;
357 let episode = Episode {
358 memory_id: MemoryId::new(),
359 ts_ms: chrono::Utc::now().timestamp_millis(),
360 source_type: args.source_type.unwrap_or_else(|| "user_message".into()),
361 source_id: args.source_id,
362 content,
363 encoding_context: EncodingContext::default(),
364 provenance: None,
365 confidence: Confidence::new(0.9).unwrap(),
366 strength: 0.5,
367 salience: 0.5,
368 tier: Tier::Hot,
369 };
370 let mid = self
371 .inner
372 .write
373 .remember(episode, embedding)
374 .await
375 .map_err(solo_to_mcp)?;
376 Ok(CallToolResult::success(vec![Content::text(format!(
377 "remembered {mid}"
378 ))]))
379 }
380
381 async fn handle_recall(
382 &self,
383 args: RecallArgs,
384 ) -> std::result::Result<CallToolResult, McpError> {
385 let result = solo_query::run_recall(
389 &self.inner.embedder,
390 &self.inner.hnsw,
391 &self.inner.pool,
392 &args.query,
393 args.limit,
394 )
395 .await
396 .map_err(solo_to_mcp)?;
397
398 if result.hits.is_empty() {
399 return Ok(CallToolResult::success(vec![Content::text(format!(
400 "no matches (index has {} vectors)",
401 result.index_len
402 ))]));
403 }
404 let body = serde_json::to_string_pretty(&result.hits).unwrap_or_else(|_| String::new());
405 Ok(CallToolResult::success(vec![Content::text(body)]))
406 }
407
408 async fn handle_forget(
409 &self,
410 args: ForgetArgs,
411 ) -> std::result::Result<CallToolResult, McpError> {
412 let mid = MemoryId::from_str(&args.memory_id).map_err(|e| {
413 McpError::invalid_params(format!("invalid memory_id: {e}"), None)
414 })?;
415 self.inner
416 .write
417 .forget(mid, args.reason)
418 .await
419 .map_err(solo_to_mcp)?;
420 Ok(CallToolResult::success(vec![Content::text(format!(
421 "forgotten {mid}"
422 ))]))
423 }
424
425 async fn handle_inspect(
426 &self,
427 args: InspectArgs,
428 ) -> std::result::Result<CallToolResult, McpError> {
429 let mid = MemoryId::from_str(&args.memory_id).map_err(|e| {
430 McpError::invalid_params(format!("invalid memory_id: {e}"), None)
431 })?;
432 let row = solo_query::inspect_one(&self.inner.pool, mid)
434 .await
435 .map_err(solo_to_mcp)?;
436 let body = serde_json::to_string_pretty(&row).unwrap_or_else(|_| String::new());
437 Ok(CallToolResult::success(vec![Content::text(body)]))
438 }
439}
440
#[cfg(test)]
mod dispatch_tests {
    use super::*;
    use serde_json::json;
    use solo_core::VectorIndex;
    use solo_storage::test_support::StubVectorIndex;
    use solo_storage::{ReaderPool, StubEmbedder, WriterActor, WriterSpawn};
    use std::sync::Arc as StdArc;

    /// Everything one test needs: a server wired to a throwaway on-disk database, the
    /// writer thread's join handle, and the temp directory holding the database.
    struct Harness {
        server: SoloMcpServer,
        _tmp: tempfile::TempDir,
        write_handle_extra: Option<solo_storage::WriteHandle>,
        join: Option<std::thread::JoinHandle<()>>,
    }

    impl Harness {
        /// Builds a server backed by stub embedder/index implementations and a fresh
        /// database file in a new temp directory.
        fn new(runtime: &tokio::runtime::Runtime) -> Self {
            let tmp = tempfile::TempDir::new().unwrap();
            let dim = 16usize;
            let db_path = tmp.path().join("test.db");
            let hnsw: StdArc<dyn VectorIndex + Send + Sync> =
                StdArc::new(StubVectorIndex::new(dim));
            let embedder: StdArc<dyn solo_core::Embedder> =
                StdArc::new(StubEmbedder::new("stub", "v1", dim));

            let conn = solo_storage::test_support::open_test_db_at(&db_path);
            let WriterSpawn { handle, join } = WriterActor::spawn(conn, hnsw.clone());

            // The reader pool is constructed inside the runtime (presumably it needs a Tokio
            // context — confirm against ReaderPool::new).
            let pool: ReaderPool =
                runtime.block_on(async { ReaderPool::new(&db_path, None, hnsw.clone()).unwrap() });

            let server = SoloMcpServer::new(handle.clone(), pool, embedder, hnsw);
            Harness {
                server,
                _tmp: tmp,
                write_handle_extra: Some(handle),
                join: Some(join),
            }
        }

        /// Tears the harness down in dependency order and waits (max 5s) for the writer
        /// thread to finish.
        fn shutdown(mut self, runtime: &tokio::runtime::Runtime) {
            let join = self.join.take();
            let extra = self.write_handle_extra.take();
            let Harness {
                server, _tmp: tmp, ..
            } = self;
            runtime.block_on(async move {
                // Drop every WriteHandle clone first; presumably the writer actor exits once
                // its last handle is gone — confirm against WriterActor.
                drop(extra);
                drop(server);
                if let Some(join) = join {
                    let (tx, rx) = std::sync::mpsc::channel();
                    std::thread::spawn(move || {
                        let _ = tx.send(join.join());
                    });
                    tokio::task::spawn_blocking(move || {
                        rx.recv_timeout(std::time::Duration::from_secs(5))
                    })
                    .await
                    .expect("blocking task")
                    .expect("writer thread did not exit within 5s")
                    .expect("writer thread panicked");
                }
            });
            // Remove the temp directory only after the writer thread (which owns the
            // database connection) has exited, so the database is not deleted underneath an
            // open connection.
            drop(tmp);
        }
    }

    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// Extracts the `text` field of the first content item, falling back to the item's
    /// full JSON rendering when it has no `text` field.
    fn first_text(r: &rmcp::model::CallToolResult) -> String {
        let first = r.content.first().expect("at least one content item");
        let v = serde_json::to_value(first).expect("content serialises");
        v.get("text")
            .and_then(|t| t.as_str())
            .map(|s| s.to_string())
            .unwrap_or_else(|| format!("{v}"))
    }

    #[test]
    fn tools_list_returns_four_canonical_tools() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        let tools = h.server.dispatch_list_tools();
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_ref()).collect();
        assert_eq!(
            names,
            vec![
                "memory.remember",
                "memory.recall",
                "memory.forget",
                "memory.inspect"
            ]
        );
        for t in &tools {
            assert!(!t.description.is_empty(), "{} description empty", t.name);
            let schema = t.schema_as_json_value();
            assert!(
                schema.get("required").is_some(),
                "{} missing 'required' field in input schema",
                t.name
            );
        }
        h.shutdown(&runtime);
    }

    #[test]
    fn remember_then_recall_round_trip() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let r = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "the cat sat on the mat" }))
                .await
                .expect("remember succeeds");
            let text = first_text(&r);
            assert!(text.starts_with("remembered "), "got: {text}");

            let r = h
                .server
                .dispatch_tool(
                    "memory.recall",
                    json!({ "query": "the cat sat on the mat", "limit": 5 }),
                )
                .await
                .expect("recall succeeds");
            let text = first_text(&r);
            assert!(text.contains("the cat sat on the mat"), "got: {text}");
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn forget_excludes_row_from_subsequent_recall() {
        let runtime = rt();
        let h = Harness::new(&runtime);

        runtime.block_on(async {
            let r = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "to be forgotten" }))
                .await
                .unwrap();
            let text = first_text(&r);
            let mid = text.strip_prefix("remembered ").unwrap().to_string();

            h.server
                .dispatch_tool(
                    "memory.forget",
                    json!({ "memory_id": mid, "reason": "test" }),
                )
                .await
                .expect("forget succeeds");

            let r = h
                .server
                .dispatch_tool(
                    "memory.recall",
                    json!({ "query": "to be forgotten", "limit": 5 }),
                )
                .await
                .unwrap();
            let text = first_text(&r);
            assert!(
                !text.contains(r#""content": "to be forgotten""#),
                "forgotten row should be excluded; got: {text}"
            );
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn empty_remember_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "" }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("must not be empty"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn empty_recall_query_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.recall", json!({ "query": " " }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("must not be empty"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn inspect_with_invalid_id_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.inspect", json!({ "memory_id": "not-a-uuid" }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("invalid memory_id"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn forget_unknown_id_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool(
                    "memory.forget",
                    json!({ "memory_id": "00000000-0000-7000-8000-000000000000" }),
                )
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("not found"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn unknown_tool_name_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.summon", json!({}))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("unknown tool"));
        });
        h.shutdown(&runtime);
    }
}
722
723