use std::sync::Arc;
use rmcp::handler::server::ServerHandler;
use rmcp::model::{
CallToolRequestParam, CallToolResult, Content, Implementation, ListToolsResult,
PaginatedRequestParam, ProtocolVersion, ServerCapabilities, ServerInfo, Tool,
ToolsCapability,
};
use rmcp::service::{RequestContext, RoleServer};
use rmcp::{Error as McpError, ServiceExt};
use serde::{Deserialize, Serialize};
use solo_core::{
Confidence, Embedder, EncodingContext, Episode, MemoryId, Tier,
VectorIndex,
};
use solo_storage::{ReaderPool, WriteHandle};
use std::str::FromStr;
/// MCP front-end for the Solo memory store.
///
/// Cloning is cheap: all state lives behind one shared [`Arc`], so every
/// clone talks to the same writer, reader pool, embedder and vector index.
#[derive(Clone)]
pub struct SoloMcpServer {
    inner: Arc<Inner>,
}

/// Shared state behind [`SoloMcpServer`].
///
/// Field order is intentional: it is also the drop order.
struct Inner {
    /// Mutation path, used by `memory.remember` and `memory.forget`.
    write: WriteHandle,
    /// Read path, used by `memory.recall` and `memory.inspect`.
    pool: ReaderPool,
    /// Turns text into embeddings for both storage and search.
    embedder: Arc<dyn Embedder>,
    /// Vector index consulted by `memory.recall`.
    hnsw: Arc<dyn VectorIndex + Send + Sync>,
}

impl SoloMcpServer {
    /// Bundle the storage, embedding and indexing components into a server.
    pub fn new(
        write: WriteHandle,
        pool: ReaderPool,
        embedder: Arc<dyn Embedder>,
        hnsw: Arc<dyn VectorIndex + Send + Sync>,
    ) -> Self {
        let shared = Inner {
            write,
            pool,
            embedder,
            hnsw,
        };
        Self {
            inner: Arc::new(shared),
        }
    }
}
/// Serve `server` over the process's stdin/stdout and return once the
/// running service stops.
///
/// # Errors
/// Propagates a failure from `serve` when starting the service, or from
/// `waiting` while awaiting its completion.
pub async fn serve_stdio(server: SoloMcpServer) -> anyhow::Result<()> {
    use rmcp::transport::io::stdio;
    let service = server.serve(stdio()).await?;
    service.waiting().await?;
    Ok(())
}
/// Arguments of the `memory.remember` tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RememberArgs {
    /// Text to store. Trailing whitespace is trimmed before embedding.
    pub content: String,
    /// Source-type tag. When absent the handler uses `"user_message"`.
    #[serde(default)]
    pub source_type: Option<String>,
    /// Optional upstream identifier, stored on the episode as-is.
    #[serde(default)]
    pub source_id: Option<String>,
}
/// Arguments of the `memory.recall` tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallArgs {
    /// Free-text query to embed and search for.
    pub query: String,
    /// Maximum number of hits. Falls back to `default_limit()` when omitted.
    #[serde(default = "default_limit")]
    pub limit: usize,
}
/// Serde default for `RecallArgs::limit` when the caller omits it.
fn default_limit() -> usize {
    const DEFAULT_RECALL_LIMIT: usize = 5;
    DEFAULT_RECALL_LIMIT
}
/// Arguments of the `memory.forget` tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgetArgs {
    /// Textual `MemoryId` of the memory to forget.
    pub memory_id: String,
    /// Free-form reason handed to the writer. Falls back to
    /// `default_forget_reason()` when omitted.
    #[serde(default = "default_forget_reason")]
    pub reason: String,
}
/// Serde default for `ForgetArgs::reason` when the caller omits it.
fn default_forget_reason() -> String {
    String::from("user-initiated via MCP")
}
/// Arguments of the `memory.inspect` tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InspectArgs {
    /// Textual `MemoryId` of the memory whose full record is returned.
    pub memory_id: String,
}
impl ServerHandler for SoloMcpServer {
    /// Advertise the server identity and its (static) tool capability.
    fn get_info(&self) -> ServerInfo {
        // The tool set never changes at runtime, so `list_changed` is false.
        let capabilities = ServerCapabilities {
            tools: Some(ToolsCapability {
                list_changed: Some(false),
            }),
            ..Default::default()
        };
        let server_info = Implementation {
            name: "solo".into(),
            version: env!("CARGO_PKG_VERSION").into(),
        };
        let instructions = "Solo: local-first personal memory for LLMs. Use \
             memory.remember to store, memory.recall to search, \
             memory.forget to soft-delete, and memory.inspect to \
             fetch a full record.";
        ServerInfo {
            protocol_version: ProtocolVersion::default(),
            capabilities,
            server_info,
            instructions: Some(instructions.into()),
        }
    }

    /// Return the whole tool catalogue in a single, unpaginated page.
    async fn list_tools(
        &self,
        _request: PaginatedRequestParam,
        _context: RequestContext<RoleServer>,
    ) -> std::result::Result<ListToolsResult, McpError> {
        let tools = self.dispatch_list_tools();
        Ok(ListToolsResult {
            next_cursor: None,
            tools,
        })
    }

    /// Forward a `tools/call` request to `dispatch_tool`.
    ///
    /// Missing `arguments` are treated as an empty JSON object, so per-tool
    /// serde defaults still apply.
    async fn call_tool(
        &self,
        request: CallToolRequestParam,
        _context: RequestContext<RoleServer>,
    ) -> std::result::Result<CallToolResult, McpError> {
        let arguments = request.arguments.unwrap_or_default();
        self.dispatch_tool(&request.name, serde_json::Value::Object(arguments))
            .await
    }
}
impl SoloMcpServer {
    /// Route a tool call by name to its handler.
    ///
    /// `args_value` is the raw JSON arguments object. It is deserialized into
    /// the tool-specific argument struct before the handler runs.
    ///
    /// # Errors
    /// `invalid_params` for an unknown tool name, or for arguments that do not
    /// fit the tool's argument struct. Otherwise, whatever the handler returns.
    pub async fn dispatch_tool(
        &self,
        name: &str,
        args_value: serde_json::Value,
    ) -> std::result::Result<CallToolResult, McpError> {
        // The argument type for `parse_args` is inferred from each handler's
        // parameter type.
        match name {
            "memory.remember" => self.handle_remember(parse_args(&args_value)?).await,
            "memory.recall" => self.handle_recall(parse_args(&args_value)?).await,
            "memory.forget" => self.handle_forget(parse_args(&args_value)?).await,
            "memory.inspect" => self.handle_inspect(parse_args(&args_value)?).await,
            unknown => Err(McpError::invalid_params(
                format!("unknown tool `{unknown}`"),
                None,
            )),
        }
    }

    /// Return the tool catalogue, i.e. the same list `tools/list` reports.
    pub fn dispatch_list_tools(&self) -> Vec<Tool> {
        build_tools()
    }
}
/// Deserialize a tool's JSON arguments into its typed argument struct.
///
/// Deserializes straight from the borrowed `Value` (serde_json implements
/// `Deserializer` for `&Value`). The previous version deep-cloned the whole
/// argument tree first.
///
/// # Errors
/// Returns `invalid_params` carrying the serde error text when `v` does not
/// match `T`.
fn parse_args<T: serde::de::DeserializeOwned>(
    v: &serde_json::Value,
) -> std::result::Result<T, McpError> {
    T::deserialize(v).map_err(|e| {
        McpError::invalid_params(format!("invalid tool arguments: {e}"), None)
    })
}
fn solo_to_mcp(e: solo_core::Error) -> McpError {
use solo_core::Error;
match e {
Error::NotFound(msg) => McpError::invalid_params(msg, None),
Error::InvalidInput(msg) => McpError::invalid_params(msg, None),
Error::Conflict(msg) => McpError::invalid_params(msg, None),
other => McpError::internal_error(other.to_string(), None),
}
}
/// Build the static catalogue of MCP tools this server exposes.
///
/// Each table row is `(name, description, JSON Schema of the arguments)`.
/// Row order is the order reported by `tools/list`.
fn build_tools() -> Vec<Tool> {
    let catalogue = vec![
        (
            "memory.remember",
            "Store a new episodic memory. Returns the new MemoryId (UUID v7).",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "content": {
                        "type": "string",
                        "description": "The text to remember.",
                    },
                    "source_type": {
                        "type": "string",
                        "description": "Optional source-type tag (default: \"user_message\").",
                    },
                    "source_id": {
                        "type": "string",
                        "description": "Optional upstream id for traceability.",
                    },
                },
                "required": ["content"],
            }),
        ),
        (
            "memory.recall",
            "Vector-search the memory store. Returns up to `limit` results \
             ordered by cosine distance (smaller = more similar). Excludes \
             forgotten memories.",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query text.",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum results (default 5).",
                        "minimum": 1,
                        "maximum": 100,
                    },
                },
                "required": ["query"],
            }),
        ),
        (
            "memory.forget",
            "Soft-delete a memory by id. The HNSW vector stays in the graph \
             but the SQL row's status flips to 'forgotten' so future recalls \
             exclude it.",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "memory_id": {
                        "type": "string",
                        "description": "MemoryId to forget (UUID v7).",
                    },
                    "reason": {
                        "type": "string",
                        "description": "Optional free-form reason (logged, not yet persisted).",
                    },
                },
                "required": ["memory_id"],
            }),
        ),
        (
            "memory.inspect",
            "Return the full record for a memory_id (timestamps, source, \
             status, scoring values, content).",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "memory_id": {
                        "type": "string",
                        "description": "MemoryId to inspect (UUID v7).",
                    },
                },
                "required": ["memory_id"],
            }),
        ),
    ];
    catalogue
        .into_iter()
        .map(|(name, description, schema)| {
            Tool::new(name, description, json_schema_object(schema))
        })
        .collect()
}
/// Unwrap a JSON object literal into the map form `Tool::new` takes as its
/// input schema.
///
/// # Panics
/// Panics if `value` is not a JSON object. Callers only pass object literals.
fn json_schema_object(value: serde_json::Value) -> serde_json::Map<String, serde_json::Value> {
    let serde_json::Value::Object(fields) = value else {
        panic!("json_schema_object: input must be an object");
    };
    fields
}
impl SoloMcpServer {
async fn handle_remember(
&self,
args: RememberArgs,
) -> std::result::Result<CallToolResult, McpError> {
let content = args.content.trim_end().to_string();
if content.is_empty() {
return Err(McpError::invalid_params(
"memory.remember: content must not be empty".to_string(),
None,
));
}
let embedding: solo_core::Embedding = self
.inner
.embedder
.embed(&content)
.await
.map_err(solo_to_mcp)?;
let episode = Episode {
memory_id: MemoryId::new(),
ts_ms: chrono::Utc::now().timestamp_millis(),
source_type: args.source_type.unwrap_or_else(|| "user_message".into()),
source_id: args.source_id,
content,
encoding_context: EncodingContext::default(),
provenance: None,
confidence: Confidence::new(0.9).unwrap(),
strength: 0.5,
salience: 0.5,
tier: Tier::Hot,
};
let mid = self
.inner
.write
.remember(episode, embedding)
.await
.map_err(solo_to_mcp)?;
Ok(CallToolResult::success(vec![Content::text(format!(
"remembered {mid}"
))]))
}
async fn handle_recall(
&self,
args: RecallArgs,
) -> std::result::Result<CallToolResult, McpError> {
let result = solo_query::run_recall(
&self.inner.embedder,
&self.inner.hnsw,
&self.inner.pool,
&args.query,
args.limit,
)
.await
.map_err(solo_to_mcp)?;
if result.hits.is_empty() {
return Ok(CallToolResult::success(vec![Content::text(format!(
"no matches (index has {} vectors)",
result.index_len
))]));
}
let body = serde_json::to_string_pretty(&result.hits).unwrap_or_else(|_| String::new());
Ok(CallToolResult::success(vec![Content::text(body)]))
}
async fn handle_forget(
&self,
args: ForgetArgs,
) -> std::result::Result<CallToolResult, McpError> {
let mid = MemoryId::from_str(&args.memory_id).map_err(|e| {
McpError::invalid_params(format!("invalid memory_id: {e}"), None)
})?;
self.inner
.write
.forget(mid, args.reason)
.await
.map_err(solo_to_mcp)?;
Ok(CallToolResult::success(vec![Content::text(format!(
"forgotten {mid}"
))]))
}
async fn handle_inspect(
&self,
args: InspectArgs,
) -> std::result::Result<CallToolResult, McpError> {
let mid = MemoryId::from_str(&args.memory_id).map_err(|e| {
McpError::invalid_params(format!("invalid memory_id: {e}"), None)
})?;
let row = solo_query::inspect_one(&self.inner.pool, mid)
.await
.map_err(solo_to_mcp)?;
let body = serde_json::to_string_pretty(&row).unwrap_or_else(|_| String::new());
Ok(CallToolResult::success(vec![Content::text(body)]))
}
}
#[cfg(test)]
mod dispatch_tests {
    use super::*;
    use serde_json::json;
    use solo_core::VectorIndex;
    use solo_storage::test_support::StubVectorIndex;
    use solo_storage::{ReaderPool, StubEmbedder, WriterActor, WriterSpawn};
    use std::sync::Arc as StdArc;

    /// Per-test fixture. Holds a `SoloMcpServer` wired to a throwaway database
    /// in a temp dir, a stub embedder and a stub vector index, plus the writer
    /// thread spawned by `WriterActor::spawn`.
    struct Harness {
        server: SoloMcpServer,
        _tmp: tempfile::TempDir,
        write_handle_extra: Option<solo_storage::WriteHandle>,
        join: Option<std::thread::JoinHandle<()>>,
    }

    impl Harness {
        fn new(runtime: &tokio::runtime::Runtime) -> Self {
            let tmp = tempfile::TempDir::new().unwrap();
            let db_path = tmp.path().join("test.db");
            let dim = 16usize;
            let hnsw: StdArc<dyn VectorIndex + Send + Sync> = StdArc::new(StubVectorIndex::new(dim));
            let embedder: StdArc<dyn solo_core::Embedder> =
                StdArc::new(StubEmbedder::new("stub", "v1", dim));
            let conn = solo_storage::test_support::open_test_db_at(&db_path);
            let WriterSpawn { handle, join } = WriterActor::spawn(conn, hnsw.clone());
            // NOTE(review): built inside the runtime, presumably because the
            // pool needs a Tokio context — confirm against ReaderPool::new.
            let pool: ReaderPool =
                runtime.block_on(async { ReaderPool::new(&db_path, None, hnsw.clone()).unwrap() });
            let server = SoloMcpServer::new(handle.clone(), pool, embedder, hnsw);
            Harness {
                server,
                _tmp: tmp,
                write_handle_extra: Some(handle),
                join: Some(join),
            }
        }

        /// Drop all writer handles, wait (up to 5s) for the writer thread to
        /// exit, and only then delete the temp dir.
        ///
        /// The writer thread owns the SQLite connection into that dir, so the
        /// dir must outlive it.
        fn shutdown(mut self, runtime: &tokio::runtime::Runtime) {
            let join = self.join.take();
            let extra = self.write_handle_extra.take();
            runtime.block_on(async move {
                drop(extra);
                drop(self.server);
                if let Some(join) = join {
                    // `JoinHandle::join` has no timeout, so a helper thread
                    // performs the blocking join and reports over a channel
                    // that can be waited on with a deadline.
                    let (tx, rx) = std::sync::mpsc::channel();
                    std::thread::spawn(move || {
                        let _ = tx.send(join.join());
                    });
                    tokio::task::spawn_blocking(move || {
                        rx.recv_timeout(std::time::Duration::from_secs(5))
                    })
                    .await
                    .expect("blocking task")
                    .expect("writer thread did not exit within 5s")
                    .expect("writer thread panicked");
                }
                // Only now is nothing left holding files inside the temp dir.
                drop(self._tmp);
            });
        }
    }

    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// Text of the first content item, or its JSON form if it has no `text`.
    fn first_text(r: &rmcp::model::CallToolResult) -> String {
        let first = r.content.first().expect("at least one content item");
        let v = serde_json::to_value(first).expect("content serialises");
        v.get("text")
            .and_then(|t| t.as_str())
            .map(|s| s.to_string())
            .unwrap_or_else(|| format!("{v}"))
    }

    #[test]
    fn tools_list_returns_four_canonical_tools() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        let tools = h.server.dispatch_list_tools();
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_ref()).collect();
        assert_eq!(
            names,
            vec![
                "memory.remember",
                "memory.recall",
                "memory.forget",
                "memory.inspect"
            ]
        );
        for t in &tools {
            assert!(!t.description.is_empty(), "{} description empty", t.name);
            let schema = t.schema_as_json_value();
            assert!(
                schema.get("required").is_some(),
                "{} missing 'required' field in input schema",
                t.name
            );
        }
        h.shutdown(&runtime);
    }

    #[test]
    fn remember_then_recall_round_trip() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let r = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "the cat sat on the mat" }))
                .await
                .expect("remember succeeds");
            let text = first_text(&r);
            assert!(text.starts_with("remembered "), "got: {text}");
            let r = h
                .server
                .dispatch_tool(
                    "memory.recall",
                    json!({ "query": "the cat sat on the mat", "limit": 5 }),
                )
                .await
                .expect("recall succeeds");
            let text = first_text(&r);
            assert!(text.contains("the cat sat on the mat"), "got: {text}");
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn forget_excludes_row_from_subsequent_recall() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let r = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "to be forgotten" }))
                .await
                .unwrap();
            let text = first_text(&r);
            let mid = text.strip_prefix("remembered ").unwrap().to_string();
            h.server
                .dispatch_tool(
                    "memory.forget",
                    json!({ "memory_id": mid, "reason": "test" }),
                )
                .await
                .expect("forget succeeds");
            let r = h
                .server
                .dispatch_tool(
                    "memory.recall",
                    json!({ "query": "to be forgotten", "limit": 5 }),
                )
                .await
                .unwrap();
            let text = first_text(&r);
            assert!(
                !text.contains(r#""content": "to be forgotten""#),
                "forgotten row should be excluded; got: {text}"
            );
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn empty_remember_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.remember", json!({ "content": "" }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("must not be empty"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn empty_recall_query_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.recall", json!({ "query": " " }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("must not be empty"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn inspect_with_invalid_id_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.inspect", json!({ "memory_id": "not-a-uuid" }))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("invalid memory_id"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn forget_unknown_id_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool(
                    "memory.forget",
                    json!({ "memory_id": "00000000-0000-7000-8000-000000000000" }),
                )
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("not found"));
        });
        h.shutdown(&runtime);
    }

    #[test]
    fn unknown_tool_name_returns_invalid_params() {
        let runtime = rt();
        let h = Harness::new(&runtime);
        runtime.block_on(async {
            let err = h
                .server
                .dispatch_tool("memory.summon", json!({}))
                .await
                .unwrap_err();
            assert!(format!("{err:?}").contains("unknown tool"));
        });
        h.shutdown(&runtime);
    }
}