Skip to main content

solo_api/
mcp.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! MCP (Model Context Protocol) server for Solo.
4//!
5//! Exposes four tools to MCP clients (Claude Desktop, Cursor, etc.):
6//!
7//!   - `memory.remember(content, source_type?, source_id?)` — store an
8//!     episode. Returns the new MemoryId.
9//!   - `memory.recall(query, limit?)` — vector search. Returns the top-K
10//!     matches with content + tier + status.
11//!   - `memory.forget(memory_id, reason?)` — soft-delete an episode.
12//!   - `memory.inspect(memory_id)` — return the full episode record.
13//!
14//! ## Transport
15//!
//! `serve_stdio` wires the server to stdin/stdout so it can run as a
//! subprocess of an MCP client — e.g. a `claude_desktop_config.json` or
//! `~/.cursor/mcp.json` entry that invokes `solo mcp-stdio`. It then waits
//! for the service to shut down, which happens when stdin closes (the
//! parent disconnects) — the same lifecycle as `solo daemon`'s Ctrl+C path.
21//!
22//! ## What's deferred
23//!
24//! - SSE/HTTP transports — `rmcp` ships them, but v0.1 ships stdio only.
25//! - `prompts/` and `resources/` capabilities — not needed for the
26//!   four-tool surface; ServerHandler defaults return empty lists.
//! - JSON Schema validation of tool arguments — arguments arrive as a raw
//!   JSON object and are serde-deserialized into the typed param structs;
//!   malformed inputs surface as `invalid_params` errors.
30
31use std::sync::Arc;
32
33use rmcp::handler::server::ServerHandler;
34use rmcp::model::{
35    CallToolRequestParam, CallToolResult, Content, Implementation, ListToolsResult,
36    PaginatedRequestParam, ProtocolVersion, ServerCapabilities, ServerInfo, Tool,
37    ToolsCapability,
38};
39use rmcp::service::{RequestContext, RoleServer};
40use rmcp::{Error as McpError, ServiceExt};
41use serde::{Deserialize, Serialize};
42use solo_core::{
43    Confidence, Embedder, EncodingContext, Episode, MemoryId, Tier,
44    VectorIndex,
45};
46use solo_storage::{ReaderPool, WriteHandle};
47use std::str::FromStr;
48
49/// The MCP server. Cheap to clone — every field is `Arc`-cloneable.
50#[derive(Clone)]
51pub struct SoloMcpServer {
52    inner: Arc<Inner>,
53}
54
55struct Inner {
56    write: WriteHandle,
57    pool: ReaderPool,
58    embedder: Arc<dyn Embedder>,
59    hnsw: Arc<dyn VectorIndex + Send + Sync>,
60}
61
62impl SoloMcpServer {
63    pub fn new(
64        write: WriteHandle,
65        pool: ReaderPool,
66        embedder: Arc<dyn Embedder>,
67        hnsw: Arc<dyn VectorIndex + Send + Sync>,
68    ) -> Self {
69        Self {
70            inner: Arc::new(Inner {
71                write,
72                pool,
73                embedder,
74                hnsw,
75            }),
76        }
77    }
78}
79
80/// Convenience: run the server over stdio and await its termination.
81/// Returns when stdin closes (parent disconnect) or the runtime exits.
82pub async fn serve_stdio(server: SoloMcpServer) -> anyhow::Result<()> {
83    use rmcp::transport::io::stdio;
84    let (stdin, stdout) = stdio();
85    let running = server.serve((stdin, stdout)).await?;
86    running.waiting().await?;
87    Ok(())
88}
89
90// ---------------------------------------------------------------------------
91// Tool argument schemas
92// ---------------------------------------------------------------------------
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct RememberArgs {
96    pub content: String,
97    #[serde(default)]
98    pub source_type: Option<String>,
99    #[serde(default)]
100    pub source_id: Option<String>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct RecallArgs {
105    pub query: String,
106    #[serde(default = "default_limit")]
107    pub limit: usize,
108}
109
/// Serde default for `RecallArgs::limit`: five hits when `limit` is omitted.
fn default_limit() -> usize {
    const DEFAULT_RECALL_LIMIT: usize = 5;
    DEFAULT_RECALL_LIMIT
}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct ForgetArgs {
116    pub memory_id: String,
117    #[serde(default = "default_forget_reason")]
118    pub reason: String,
119}
120
/// Serde default for `ForgetArgs::reason` when the client gives no reason.
fn default_forget_reason() -> String {
    String::from("user-initiated via MCP")
}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct InspectArgs {
127    pub memory_id: String,
128}
129
130// ---------------------------------------------------------------------------
131// ServerHandler implementation
132// ---------------------------------------------------------------------------
133
134impl ServerHandler for SoloMcpServer {
135    fn get_info(&self) -> ServerInfo {
136        ServerInfo {
137            protocol_version: ProtocolVersion::default(),
138            capabilities: ServerCapabilities {
139                tools: Some(ToolsCapability {
140                    list_changed: Some(false),
141                }),
142                ..Default::default()
143            },
144            server_info: Implementation {
145                name: "solo".into(),
146                version: env!("CARGO_PKG_VERSION").into(),
147            },
148            instructions: Some(
149                "Solo: local-first personal memory for LLMs. Use \
150                 memory.remember to store, memory.recall to search, \
151                 memory.forget to soft-delete, and memory.inspect to \
152                 fetch a full record."
153                    .into(),
154            ),
155        }
156    }
157
158    async fn list_tools(
159        &self,
160        _request: PaginatedRequestParam,
161        _context: RequestContext<RoleServer>,
162    ) -> std::result::Result<ListToolsResult, McpError> {
163        Ok(ListToolsResult {
164            tools: build_tools(),
165            next_cursor: None,
166        })
167    }
168
169    async fn call_tool(
170        &self,
171        request: CallToolRequestParam,
172        _context: RequestContext<RoleServer>,
173    ) -> std::result::Result<CallToolResult, McpError> {
174        let CallToolRequestParam { name, arguments } = request;
175        let args_value = serde_json::Value::Object(arguments.unwrap_or_default());
176        self.dispatch_tool(&name, args_value).await
177    }
178}
179
180impl SoloMcpServer {
181    /// Direct tool-dispatch path used by both `call_tool` (the
182    /// ServerHandler trait method, behind the rmcp protocol layer) and
183    /// in-process tests that don't want to spin up a full transport pair.
184    /// Bypasses `RequestContext` (which requires a `Peer` not constructible
185    /// outside rmcp internals).
186    pub async fn dispatch_tool(
187        &self,
188        name: &str,
189        args_value: serde_json::Value,
190    ) -> std::result::Result<CallToolResult, McpError> {
191        match name {
192            "memory.remember" => {
193                let args: RememberArgs = parse_args(&args_value)?;
194                self.handle_remember(args).await
195            }
196            "memory.recall" => {
197                let args: RecallArgs = parse_args(&args_value)?;
198                self.handle_recall(args).await
199            }
200            "memory.forget" => {
201                let args: ForgetArgs = parse_args(&args_value)?;
202                self.handle_forget(args).await
203            }
204            "memory.inspect" => {
205                let args: InspectArgs = parse_args(&args_value)?;
206                self.handle_inspect(args).await
207            }
208            other => Err(McpError::invalid_params(
209                format!("unknown tool `{other}`"),
210                None,
211            )),
212        }
213    }
214
215    /// List the tools this server exposes. Mirrors `ServerHandler::list_tools`
216    /// without requiring a RequestContext.
217    pub fn dispatch_list_tools(&self) -> Vec<Tool> {
218        build_tools()
219    }
220}
221
222fn parse_args<T: serde::de::DeserializeOwned>(
223    v: &serde_json::Value,
224) -> std::result::Result<T, McpError> {
225    serde_json::from_value(v.clone()).map_err(|e| {
226        McpError::invalid_params(format!("invalid tool arguments: {e}"), None)
227    })
228}
229
230fn solo_to_mcp(e: solo_core::Error) -> McpError {
231    use solo_core::Error;
232    match e {
233        Error::NotFound(msg) => McpError::invalid_params(msg, None),
234        Error::InvalidInput(msg) => McpError::invalid_params(msg, None),
235        Error::Conflict(msg) => McpError::invalid_params(msg, None),
236        other => McpError::internal_error(other.to_string(), None),
237    }
238}
239
240// ---------------------------------------------------------------------------
241// Tool definitions (JSON Schema)
242// ---------------------------------------------------------------------------
243
244fn build_tools() -> Vec<Tool> {
245    vec![
246        Tool::new(
247            "memory.remember",
248            "Store a new episodic memory. Returns the new MemoryId (UUID v7).",
249            json_schema_object(serde_json::json!({
250                "type": "object",
251                "properties": {
252                    "content": {
253                        "type": "string",
254                        "description": "The text to remember.",
255                    },
256                    "source_type": {
257                        "type": "string",
258                        "description": "Optional source-type tag (default: \"user_message\").",
259                    },
260                    "source_id": {
261                        "type": "string",
262                        "description": "Optional upstream id for traceability.",
263                    },
264                },
265                "required": ["content"],
266            })),
267        ),
268        Tool::new(
269            "memory.recall",
270            "Vector-search the memory store. Returns up to `limit` results \
271             ordered by cosine distance (smaller = more similar). Excludes \
272             forgotten memories.",
273            json_schema_object(serde_json::json!({
274                "type": "object",
275                "properties": {
276                    "query": {
277                        "type": "string",
278                        "description": "The query text.",
279                    },
280                    "limit": {
281                        "type": "integer",
282                        "description": "Maximum results (default 5).",
283                        "minimum": 1,
284                        "maximum": 100,
285                    },
286                },
287                "required": ["query"],
288            })),
289        ),
290        Tool::new(
291            "memory.forget",
292            "Soft-delete a memory by id. The HNSW vector stays in the graph \
293             but the SQL row's status flips to 'forgotten' so future recalls \
294             exclude it.",
295            json_schema_object(serde_json::json!({
296                "type": "object",
297                "properties": {
298                    "memory_id": {
299                        "type": "string",
300                        "description": "MemoryId to forget (UUID v7).",
301                    },
302                    "reason": {
303                        "type": "string",
304                        "description": "Optional free-form reason (logged, not yet persisted).",
305                    },
306                },
307                "required": ["memory_id"],
308            })),
309        ),
310        Tool::new(
311            "memory.inspect",
312            "Return the full record for a memory_id (timestamps, source, \
313             status, scoring values, content).",
314            json_schema_object(serde_json::json!({
315                "type": "object",
316                "properties": {
317                    "memory_id": {
318                        "type": "string",
319                        "description": "MemoryId to inspect (UUID v7).",
320                    },
321                },
322                "required": ["memory_id"],
323            })),
324        ),
325    ]
326}
327
328fn json_schema_object(value: serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
329    match value {
330        serde_json::Value::Object(map) => map,
331        _ => panic!("json_schema_object: input must be an object"),
332    }
333}
334
335// ---------------------------------------------------------------------------
336// Tool handlers
337// ---------------------------------------------------------------------------
338
339impl SoloMcpServer {
340    async fn handle_remember(
341        &self,
342        args: RememberArgs,
343    ) -> std::result::Result<CallToolResult, McpError> {
344        let content = args.content.trim_end().to_string();
345        if content.is_empty() {
346            return Err(McpError::invalid_params(
347                "memory.remember: content must not be empty".to_string(),
348                None,
349            ));
350        }
351        let embedding: solo_core::Embedding = self
352            .inner
353            .embedder
354            .embed(&content)
355            .await
356            .map_err(solo_to_mcp)?;
357        let episode = Episode {
358            memory_id: MemoryId::new(),
359            ts_ms: chrono::Utc::now().timestamp_millis(),
360            source_type: args.source_type.unwrap_or_else(|| "user_message".into()),
361            source_id: args.source_id,
362            content,
363            encoding_context: EncodingContext::default(),
364            provenance: None,
365            confidence: Confidence::new(0.9).unwrap(),
366            strength: 0.5,
367            salience: 0.5,
368            tier: Tier::Hot,
369        };
370        let mid = self
371            .inner
372            .write
373            .remember(episode, embedding)
374            .await
375            .map_err(solo_to_mcp)?;
376        Ok(CallToolResult::success(vec![Content::text(format!(
377            "remembered {mid}"
378        ))]))
379    }
380
381    async fn handle_recall(
382        &self,
383        args: RecallArgs,
384    ) -> std::result::Result<CallToolResult, McpError> {
385        // Pipeline lives in solo-query; the transport just formats the
386        // result. solo_query::run_recall validates empty queries
387        // (returns InvalidInput → invalid_params via solo_to_mcp).
388        let result = solo_query::run_recall(
389            &self.inner.embedder,
390            &self.inner.hnsw,
391            &self.inner.pool,
392            &args.query,
393            args.limit,
394        )
395        .await
396        .map_err(solo_to_mcp)?;
397
398        if result.hits.is_empty() {
399            return Ok(CallToolResult::success(vec![Content::text(format!(
400                "no matches (index has {} vectors)",
401                result.index_len
402            ))]));
403        }
404        let body = serde_json::to_string_pretty(&result.hits).unwrap_or_else(|_| String::new());
405        Ok(CallToolResult::success(vec![Content::text(body)]))
406    }
407
408    async fn handle_forget(
409        &self,
410        args: ForgetArgs,
411    ) -> std::result::Result<CallToolResult, McpError> {
412        let mid = MemoryId::from_str(&args.memory_id).map_err(|e| {
413            McpError::invalid_params(format!("invalid memory_id: {e}"), None)
414        })?;
415        self.inner
416            .write
417            .forget(mid, args.reason)
418            .await
419            .map_err(solo_to_mcp)?;
420        Ok(CallToolResult::success(vec![Content::text(format!(
421            "forgotten {mid}"
422        ))]))
423    }
424
425    async fn handle_inspect(
426        &self,
427        args: InspectArgs,
428    ) -> std::result::Result<CallToolResult, McpError> {
429        let mid = MemoryId::from_str(&args.memory_id).map_err(|e| {
430            McpError::invalid_params(format!("invalid memory_id: {e}"), None)
431        })?;
432        // Pipeline lives in solo-query::inspect; transports just format.
433        let row = solo_query::inspect_one(&self.inner.pool, mid)
434            .await
435            .map_err(solo_to_mcp)?;
436        let body = serde_json::to_string_pretty(&row).unwrap_or_else(|_| String::new());
437        Ok(CallToolResult::success(vec![Content::text(body)]))
438    }
439}
440
#[cfg(test)]
mod dispatch_tests {
    //! In-process integration tests for the MCP tool surface. We invoke
    //! `SoloMcpServer::dispatch_tool` directly (bypasses the rmcp
    //! protocol framing + `RequestContext`, which requires a `Peer`
    //! that's not constructible outside rmcp internals). The server is
    //! constructed against a real WriterActor + ReaderPool +
    //! StubEmbedder + StubVectorIndex from `solo_storage::test_support`.
    //!
    //! Tests live inline in this module rather than `tests/` because an
    //! external integration-test exe in `target/debug/deps/mcp_dispatch-*`
    //! tripped Windows UAC ERROR_ELEVATION_REQUIRED on the dev machine.
    //! The lib test binary doesn't have that issue.
    use super::*;
    use serde_json::json;
    use solo_core::VectorIndex;
    use solo_storage::test_support::StubVectorIndex;
    use solo_storage::{ReaderPool, StubEmbedder, WriterActor, WriterSpawn};
    use std::sync::Arc as StdArc;

    /// Everything one test needs: the server, the temp dir holding its
    /// SQLite file, an extra writer handle, and the writer thread's join
    /// handle.
    ///
    /// NOTE(review): if a test panics before `shutdown`, the harness is
    /// dropped outside `block_on`; per the comment in `shutdown`, dropping
    /// the reader pool without a live reactor panics, which can turn a test
    /// failure into an abort.
    struct Harness {
        server: SoloMcpServer,
        _tmp: tempfile::TempDir,
        write_handle_extra: Option<solo_storage::WriteHandle>,
        join: Option<std::thread::JoinHandle<()>>,
    }

    impl Harness {
        fn new(runtime: &tokio::runtime::Runtime) -> Self {
            let tmp = tempfile::TempDir::new().unwrap();
            let db_path = tmp.path().join("test.db");
            let dim = 16usize;
            let hnsw: StdArc<dyn VectorIndex + Send + Sync> = StdArc::new(StubVectorIndex::new(dim));
            let embedder: StdArc<dyn solo_core::Embedder> = StdArc::new(StubEmbedder::new("stub", "v1", dim));

            let conn = solo_storage::test_support::open_test_db_at(&db_path);
            let WriterSpawn { handle, join } = WriterActor::spawn(conn, hnsw.clone());

            // ReaderPool's deadpool::Pool needs a live tokio runtime for
            // both build + drop; build inside block_on.
            let pool: ReaderPool =
                runtime.block_on(async { ReaderPool::new(&db_path, None, hnsw.clone()).unwrap() });

            let server = SoloMcpServer::new(handle.clone(), pool, embedder, hnsw);
            Harness {
                server,
                _tmp: tmp,
                write_handle_extra: Some(handle),
                join: Some(join),
            }
        }

        fn shutdown(self, runtime: &tokio::runtime::Runtime) {
            // The whole shutdown runs inside block_on so deadpool-sqlite's
            // drop (which schedules cleanup on the active runtime) sees a
            // live reactor. Without this, dropping the SoloMcpServer
            // (which holds the ReaderPool through its Arc<Inner>) panics
            // with "no reactor running".
            let Harness {
                server,
                _tmp: tmp,
                write_handle_extra,
                join,
            } = self;
            runtime.block_on(async move {
                // Drop every WriteHandle so the writer thread can exit.
                drop(write_handle_extra);
                drop(server);
                if let Some(join) = join {
                    let (tx, rx) = std::sync::mpsc::channel();
                    std::thread::spawn(move || {
                        let _ = tx.send(join.join());
                    });
                    tokio::task::spawn_blocking(move || {
                        rx.recv_timeout(std::time::Duration::from_secs(5))
                    })
                    .await
                    .expect("blocking task")
                    .expect("writer thread did not exit within 5s")
                    .expect("writer thread panicked");
                }
                // Remove the temp dir only after the writer thread (which
                // owned the SQLite connection) has exited. Deleting it while
                // the DB file is still open fails on Windows, and TempDir's
                // drop swallows that error and leaks the directory.
                drop(tmp);
            });
        }
    }

    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// Pull the first Content::text body out of a CallToolResult. Use
    /// serde_json roundtrip as a robust extractor — `Content`'s public
    /// API doesn't directly expose the inner text without going through
    /// pattern-matching on RawContent.
    fn first_text(r: &rmcp::model::CallToolResult) -> String {
        let first = r.content.first().expect("at least one content item");
        let v = serde_json::to_value(first).expect("content serialises");
        v.get("text")
            .and_then(|t| t.as_str())
            .map(|s| s.to_string())
            .unwrap_or_else(|| format!("{v}"))
    }

    #[test]
    fn tools_list_returns_four_canonical_tools() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        let tools = h.server.dispatch_list_tools();
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_ref()).collect();
        assert_eq!(
            names,
            vec![
                "memory.remember",
                "memory.recall",
                "memory.forget",
                "memory.inspect"
            ]
        );
        for t in &tools {
            assert!(!t.description.is_empty(), "{} description empty", t.name);
            let schema = t.schema_as_json_value();
            assert!(
                schema.get("required").is_some(),
                "{} missing 'required' field in input schema",
                t.name
            );
        }
        h.shutdown(&runtime);
    }

    #[test]
    fn remember_then_recall_round_trip() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        // Use &h.server directly (no clone) so the only outstanding
        // reference at shutdown time is the harness's own. The clone
        // path triggered a 5-second writer-thread timeout because the
        // local clone held an Arc<Inner> with its own WriteHandle past
        // h.shutdown().
        runtime.block_on(async {
            let r = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "the cat sat on the mat" }))
                .await
                .expect("remember succeeds");
            let text = first_text(&r);
            assert!(text.starts_with("remembered "), "got: {text}");

            let r = h
                .server
                .dispatch_tool(
                    "memory.recall",
                    json!({ "query": "the cat sat on the mat", "limit": 5 }),
                )
                .await
                .expect("recall succeeds");
            let text = first_text(&r);
            assert!(text.contains("the cat sat on the mat"), "got: {text}");
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn forget_excludes_row_from_subsequent_recall() {
        let runtime = rt();
        let h = Harness::new(&runtime);

        runtime.block_on(async {
            let r = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "to be forgotten" }))
                .await
                .unwrap();
            let text = first_text(&r);
            let mid = text.strip_prefix("remembered ").unwrap().to_string();

            h.server
                .dispatch_tool(
                    "memory.forget",
                    json!({ "memory_id": mid, "reason": "test" }),
                )
                .await
                .expect("forget succeeds");

            let r = h
                .server
                .dispatch_tool(
                    "memory.recall",
                    json!({ "query": "to be forgotten", "limit": 5 }),
                )
                .await
                .unwrap();
            let text = first_text(&r);
            assert!(
                !text.contains(r#""content": "to be forgotten""#),
                "forgotten row should be excluded; got: {text}"
            );
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn empty_remember_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "" }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("must not be empty"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn empty_recall_query_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.recall", json!({ "query": "   " }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("must not be empty"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn inspect_with_invalid_id_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.inspect", json!({ "memory_id": "not-a-uuid" }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("invalid memory_id"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn forget_unknown_id_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            // Valid UUID format but not in episodes — handle_forget
            // surfaces NotFound, mapped to invalid_params per
            // solo_to_mcp.
            let err = h
                .server
                .dispatch_tool(
                    "memory.forget",
                    json!({ "memory_id": "00000000-0000-7000-8000-000000000000" }),
                )
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("not found"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn unknown_tool_name_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.summon", json!({}))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("unknown tool"));
        });
        h.shutdown(&runtime);
    }
}
722
723// fetch_recall_rows + RecallHit + RecallRow used to live here. Recall
724// pipeline moved to solo_query::recall in commit (consolidate-recall);
725// transports just call solo_query::run_recall and format the result.