use crate::errors::Error;
use crate::mcp::params::*;
use crate::memory::MemoryStore;
use crate::memory::lifecycle::{MemoryStatus, MemoryType};
use crate::memory_types::ConflictMemory as InternalConflictMemory;
use crate::memory_types::IngestPolicy;
use rmcp::handler::server::tool::ToolRouter;
use rmcp::handler::server::wrapper::Parameters;
use rmcp::model::{CallToolResult, Content, ServerCapabilities, ServerInfo};
use rmcp::tool;
use rmcp::tool_handler;
use rmcp::tool_router;
use std::sync::{Arc, Mutex};
/// Thin adapter around the shared, mutex-guarded memory store; translates
/// store operations into MCP-friendly JSON values and `rmcp::ErrorData`.
pub(crate) struct StoreWrapper(Arc<Mutex<MemoryStore>>);
impl StoreWrapper {
#[allow(dead_code)] pub(crate) fn new(store: Arc<Mutex<MemoryStore>>) -> Self {
Self(store)
}
#[allow(dead_code)] pub(crate) fn ingest(
&self,
project_id: &str,
text: &str,
metadata: &str,
force: bool,
) -> Result<serde_json::Value, rmcp::ErrorData> {
let mut store = self.0.lock().unwrap();
let policy = if force {
IngestPolicy::Force
} else {
IngestPolicy::ConflictAware
};
match store
.ingest(project_id, text, Some(metadata), policy)
.map_err(|e| -> rmcp::ErrorData { e.into() })?
{
crate::memory_types::AddResult::Added { id } => {
Ok(serde_json::to_value(SuccessResponse {
id,
status: "added".to_string(),
})
.map_err(McpError::from_serde_error)?)
}
crate::memory_types::AddResult::Conflicts {
proposed,
conflicts,
} => Err(McpError::conflict(
&proposed,
conflicts
.into_iter()
.map(|c: InternalConflictMemory| ConflictMemory {
id: c.id,
content: c.content,
similarity: c.similarity,
})
.collect(),
)),
}
}
#[allow(dead_code)] pub(crate) fn ingest_with_type_status(
&self,
project_id: &str,
text: &str,
metadata: &str,
force: bool,
memory_type: &str,
status: &str,
) -> Result<serde_json::Value, rmcp::ErrorData> {
let mut store = self.0.lock().unwrap();
let policy = if force {
IngestPolicy::Force
} else {
IngestPolicy::ConflictAware
};
match store
.ingest_with_type_status(
project_id,
text,
Some(metadata),
policy,
memory_type,
status,
)
.map_err(|e| -> rmcp::ErrorData { e.into() })?
{
crate::memory_types::AddResult::Added { id } => {
Ok(serde_json::to_value(SuccessResponse {
id,
status: "added".to_string(),
})
.map_err(McpError::from_serde_error)?)
}
crate::memory_types::AddResult::Conflicts {
proposed,
conflicts,
} => Err(McpError::conflict(
&proposed,
conflicts
.into_iter()
.map(|c: InternalConflictMemory| ConflictMemory {
id: c.id,
content: c.content,
similarity: c.similarity,
})
.collect(),
)),
}
}
#[allow(dead_code)] pub(crate) fn supersede(
&self,
project_id: &str,
new_text: &str,
metadata: &str,
memory_type: &str,
old_memory_id: &str,
) -> Result<serde_json::Value, rmcp::ErrorData> {
let mut store = self.0.lock().unwrap();
let embedding = if store.embedder.is_none() {
crate::memory::crud::mock_embedding_for_content(new_text)
} else {
store.embedder()?.embed(new_text)?
};
let metadata_str = if metadata == "null" {
None
} else {
Some(metadata)
};
let new_id = match store.db.supersede(
project_id,
new_text,
&embedding,
metadata_str,
memory_type,
old_memory_id,
) {
Ok(id) => id,
Err(err) => {
let err: Error = err.into();
return Err(err.into());
}
};
serde_json::to_value(SuccessResponse {
id: new_id,
status: "superseded".to_string(),
})
.map_err(McpError::from_serde_error)
}
pub(crate) fn search(
&self,
project_id: &str,
query: &str,
limit: usize,
recency_weight: f64,
memory_types: Option<&[&str]>,
statuses: Option<&[&str]>,
) -> Result<serde_json::Value, Error> {
let mut store = self.0.lock().unwrap();
let memories = store.search(
project_id,
query,
limit,
recency_weight,
memory_types,
statuses,
)?;
let results: Vec<serde_json::Value> = memories
.into_iter()
.map(|m| {
let metadata_value = match m.metadata {
Some(ref meta_str) if meta_str.trim() != "null" => {
serde_json::from_str::<serde_json::Value>(meta_str)
.unwrap_or_else(|_| serde_json::Value::String(meta_str.clone()))
}
_ => serde_json::Value::Null,
};
serde_json::json!({
"id": m.id,
"content": m.content,
"similarity": m.similarity.unwrap_or(0.0),
"created_at": m.created_at,
"updated_at": m.updated_at,
"project_id": m.project_id,
"metadata": metadata_value
})
})
.collect();
Ok(serde_json::to_value(results)?)
}
#[allow(dead_code)] pub(crate) fn search_hybrid(
&self,
project_id: &str,
query: &str,
limit: usize,
recency_weight: f64,
memory_types: Option<&[&str]>,
statuses: Option<&[&str]>,
) -> Result<serde_json::Value, Error> {
let mut store = self.0.lock().unwrap();
let memories = store.search_hybrid(
project_id,
query,
limit,
recency_weight,
memory_types,
statuses,
)?;
let results: Vec<serde_json::Value> = memories
.into_iter()
.map(|m| {
let metadata_value = match m.metadata {
Some(ref meta_str) if meta_str.trim() != "null" => {
serde_json::from_str::<serde_json::Value>(meta_str)
.unwrap_or_else(|_| serde_json::Value::String(meta_str.clone()))
}
_ => serde_json::Value::Null,
};
serde_json::json!({
"id": m.id,
"content": m.content,
"similarity": m.similarity.unwrap_or(0.0),
"created_at": m.created_at,
"updated_at": m.updated_at,
"project_id": m.project_id,
"metadata": metadata_value
})
})
.collect();
Ok(serde_json::to_value(results)?)
}
pub(crate) fn list(
&self,
project_id: &str,
limit: usize,
memory_types: Option<&[&str]>,
statuses: Option<&[&str]>,
) -> Result<serde_json::Value, Error> {
let store = self.0.lock().unwrap();
let memories = store.list(project_id, limit, memory_types, statuses)?;
let results: Vec<serde_json::Value> = memories
.into_iter()
.map(|m| {
let metadata_value = match m.metadata {
Some(ref meta_str) if meta_str.trim() != "null" => {
serde_json::from_str::<serde_json::Value>(meta_str)
.unwrap_or_else(|_| serde_json::Value::String(meta_str.clone()))
}
_ => serde_json::Value::Null,
};
serde_json::json!({
"id": m.id,
"content": m.content,
"created_at": m.created_at,
"updated_at": m.updated_at,
"project_id": m.project_id,
"metadata": metadata_value
})
})
.collect();
Ok(serde_json::to_value(results)?)
}
}
/// MCP server handler exposing the memory tools for a single project.
pub struct ToolHandler {
    // Generated by the #[tool_router] macro; dispatches tool calls.
    tool_router: ToolRouter<Self>,
    // Shared memory store, wrapped for JSON/error adaptation.
    store: StoreWrapper,
    // All tool calls are scoped to this project.
    project_id: String,
    // Runtime defaults (e.g. recency weight, hybrid search flag).
    config: crate::config::Config,
}
impl ToolHandler {
    /// Build a handler bound to one project: wraps the shared store and
    /// wires up the macro-generated tool router.
    pub fn new(
        store: Arc<Mutex<MemoryStore>>,
        project_id: String,
        config: crate::config::Config,
    ) -> Self {
        let store = StoreWrapper(store);
        Self {
            store,
            project_id,
            config,
            tool_router: Self::tool_router(),
        }
    }
}
#[derive(Debug)]
/// Namespace for constructing `rmcp::ErrorData` values with a consistent
/// `{"type": …}` tag in the structured payload; never instantiated.
pub struct McpError;
impl McpError {
    /// Shared constructor: an `ErrorData` whose structured payload carries a
    /// single `"type"` tag so clients can dispatch on the error category.
    fn tagged(code: rmcp::model::ErrorCode, message: String, kind: &str) -> rmcp::ErrorData {
        rmcp::ErrorData::new(code, message, Some(serde_json::json!({"type": kind})))
    }

    /// Caller-side input error (INVALID_REQUEST, tagged `invalid_input`).
    pub fn invalid_input(message: &str) -> rmcp::ErrorData {
        Self::tagged(
            rmcp::model::ErrorCode::INVALID_REQUEST,
            message.to_string(),
            "invalid_input",
        )
    }

    /// Server-side failure (INTERNAL_ERROR, tagged `internal_error`).
    pub fn internal_error(message: &str) -> rmcp::ErrorData {
        Self::tagged(
            rmcp::model::ErrorCode::INTERNAL_ERROR,
            message.to_string(),
            "internal_error",
        )
    }

    /// Conflict between a proposed memory and existing similar memories.
    /// The payload includes both the proposed text and the conflicting
    /// memories so the client can resolve, edit, or force the ingest.
    pub fn conflict(proposed: &str, conflicts: Vec<ConflictMemory>) -> rmcp::ErrorData {
        rmcp::ErrorData::new(
            rmcp::model::ErrorCode::INVALID_REQUEST,
            format!(
                "Proposed memory conflicts with {} existing memory(ies)",
                conflicts.len()
            ),
            Some(serde_json::json!({
                "type": "conflict",
                "proposed": proposed,
                "conflicts": conflicts
            })),
        )
    }

    /// Wrap a serde serialization failure as an internal error.
    fn from_serde_error(e: serde_json::Error) -> rmcp::ErrorData {
        Self::tagged(
            rmcp::model::ErrorCode::INTERNAL_ERROR,
            format!("Serialization error: {}", e),
            "internal_error",
        )
    }
}
impl From<Error> for rmcp::ErrorData {
fn from(e: Error) -> Self {
match e {
Error::EmptyInput => McpError::invalid_input("Text cannot be empty"),
Error::InputTooLong {
max_length,
actual_length,
} => McpError::invalid_input(&format!(
"Input too long: {} characters (max: {})",
actual_length, max_length
)),
Error::InvalidInput(msg) => McpError::invalid_input(&msg),
Error::Validation(msg) => McpError::invalid_input(&msg),
Error::InvalidTimestamp { timestamp, error } => {
McpError::invalid_input(&format!("Invalid timestamp '{}': {}", timestamp, error))
}
Error::ContentTooLong {
token_count,
max_tokens,
} => McpError::invalid_input(&format!(
"Content exceeds {}-token embedding limit (measured: {} tokens)",
max_tokens, token_count
)),
Error::NotFound(msg) => rmcp::ErrorData::new(
rmcp::model::ErrorCode::INVALID_REQUEST,
msg,
Some(serde_json::json!({"type": "not_found"})),
),
_ => McpError::internal_error(&e.to_string()),
}
}
}
#[tool_router]
impl ToolHandler {
#[tool(
name = "store_memory",
description = "Store information for later recall. Use this when you learn something worth remembering — facts, decisions, preferences, or context about a project. Memories are searchable by meaning, not just keywords."
)]
async fn store_memory(
&self,
Parameters(params): Parameters<StoreMemoryParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
if params.text.trim().is_empty() {
return Err(McpError::invalid_input("Text cannot be empty"));
}
let memory_type = params.memory_type.as_deref().unwrap_or("fact");
let _ = MemoryType::from_str(memory_type)
.map_err(|e| McpError::invalid_input(&format!("Invalid memory type: {}", e)))?;
let status_str = params.status.as_deref().unwrap_or("active");
let status_val = MemoryStatus::from_str(status_str)
.map_err(|e| McpError::invalid_input(&format!("Invalid status: {}", e)))?;
if !status_val.is_valid_for_insert() {
return Err(McpError::invalid_input(&format!(
"Status '{}' is not valid for new memory. Must be 'active' or 'candidate'.",
status_str
)));
}
let metadata_str = match ¶ms.metadata {
Some(meta) => serde_json::to_string(meta)
.map_err(|e| McpError::invalid_input(&format!("Invalid metadata: {}", e)))?,
None => "null".to_string(),
};
if let Some(supersedes_id) = ¶ms.supersedes {
let value = self.store.supersede(
&self.project_id,
¶ms.text,
&metadata_str,
memory_type,
supersedes_id,
)?;
return Ok(CallToolResult::success(vec![Content::text(
serde_json::to_string(&value).map_err(McpError::from_serde_error)?,
)]));
}
let value = self.store.ingest_with_type_status(
&self.project_id,
¶ms.text,
&metadata_str,
false,
memory_type,
status_str,
)?;
Ok(CallToolResult::success(vec![Content::text(
serde_json::to_string(&value).map_err(McpError::from_serde_error)?,
)]))
}
#[tool(
name = "search_memories",
description = "Search memories by meaning. Describe what you are looking for in natural language. Use this when you need to recall previously stored information, check if a topic was discussed, or find related context."
)]
async fn search_memories(
&self,
Parameters(params): Parameters<SearchMemoriesParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
if params.query.trim().is_empty() {
return Err(McpError::invalid_input("Query cannot be empty"));
}
let limit = params.limit.unwrap_or(5);
if limit == 0 {
return Err(McpError::invalid_input("Limit must be greater than 0"));
}
if limit > 10_000 {
return Err(McpError::invalid_input(
"Limit exceeds maximum allowed (10000)",
));
}
let type_strs: Option<Vec<&str>> = params
.memory_types
.as_ref()
.map(|v| v.iter().map(|s| s.as_str()).collect());
let type_slice: Option<&[&str]> = type_strs.as_deref();
let status_strs: Option<Vec<&str>> = params
.statuses
.as_ref()
.map(|v| v.iter().map(|s| s.as_str()).collect());
let status_slice: Option<&[&str]> = status_strs.as_deref();
let recency_weight = params.recency_weight.unwrap_or(self.config.recency_weight);
let use_hybrid = params.hybrid.unwrap_or(self.config.hybrid);
let value = if use_hybrid {
self.store
.search_hybrid(
&self.project_id,
¶ms.query,
limit,
recency_weight,
type_slice,
status_slice,
)
.map_err(|e: Error| -> rmcp::ErrorData { e.into() })?
} else {
self.store
.search(
&self.project_id,
¶ms.query,
limit,
recency_weight,
type_slice,
status_slice,
)
.map_err(|e: Error| -> rmcp::ErrorData { e.into() })?
};
Ok(CallToolResult::success(vec![Content::text(
serde_json::to_string(&value).map_err(McpError::from_serde_error)?,
)]))
}
#[tool(
name = "list_memories",
description = "List recent memories. Use this to review what was recently stored, get an overview of stored knowledge, or find memories when you are not sure what to search for."
)]
async fn list_memories(
&self,
Parameters(params): Parameters<ListMemoriesParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
let limit = params.limit.unwrap_or(10);
if limit == 0 {
return Err(McpError::invalid_input("Limit must be greater than 0"));
}
if limit > 10_000 {
return Err(McpError::invalid_input(
"Limit exceeds maximum allowed (10000)",
));
}
let type_strs: Option<Vec<&str>> = params
.memory_types
.as_ref()
.map(|v| v.iter().map(|s| s.as_str()).collect());
let type_slice: Option<&[&str]> = type_strs.as_deref();
let status_strs: Option<Vec<&str>> = params
.statuses
.as_ref()
.map(|v| v.iter().map(|s| s.as_str()).collect());
let status_slice: Option<&[&str]> = status_strs.as_deref();
let value = self
.store
.list(&self.project_id, limit, type_slice, status_slice)
.map_err(|e: Error| -> rmcp::ErrorData { e.into() })?;
Ok(CallToolResult::success(vec![Content::text(
serde_json::to_string(&value).map_err(McpError::from_serde_error)?,
)]))
}
#[tool(
name = "supersede_memory",
description = "Replace an existing memory with new content. The old memory is marked as superseded and a new memory is created. Use this when information has changed and the old version should no longer appear in search results."
)]
async fn supersede_memory(
&self,
Parameters(params): Parameters<SupersedeMemoryParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
if params.new_text.trim().is_empty() {
return Err(McpError::invalid_input("new_text cannot be empty"));
}
if params.old_memory_id.trim().is_empty() {
return Err(McpError::invalid_input("old_memory_id cannot be empty"));
}
let memory_type = params.memory_type.as_deref().unwrap_or("fact");
let _ = MemoryType::from_str(memory_type)
.map_err(|e| McpError::invalid_input(&format!("Invalid memory type: {}", e)))?;
let metadata_str = match ¶ms.metadata {
Some(meta) => serde_json::to_string(meta)
.map_err(|e| McpError::invalid_input(&format!("Invalid metadata: {}", e)))?,
None => "null".to_string(),
};
let value = self.store.supersede(
&self.project_id,
¶ms.new_text,
&metadata_str,
memory_type,
¶ms.old_memory_id,
)?;
Ok(CallToolResult::success(vec![Content::text(
serde_json::to_string(&value).map_err(McpError::from_serde_error)?,
)]))
}
}
#[tool_handler]
impl rmcp::ServerHandler for ToolHandler {
    // Advertise tool support; tool listing/dispatch is generated by the
    // #[tool_handler] macro from the router built in ToolHandler::new.
    fn get_info(&self) -> ServerInfo {
        ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
    }
}