use std::{sync::Arc, time::Instant};
const SNIPPET_MAX_CHARS: usize = 500;
use chrono::Utc;
use rmcp::{
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::{ErrorData, ServerCapabilities, ServerInfo},
tool, tool_handler, tool_router, ServerHandler,
};
use tracing::{info, warn, Instrument};
use crate::{
embedding::EmbeddingBackend,
error::MemoryError,
index::ScopedIndex,
repo::MemoryRepo,
types::{
parse_qualified_name, parse_scope, parse_scope_filter, validate_name, AppState,
ChangedMemories, EditArgs, ForgetArgs, ListArgs, Memory, MemoryMetadata, PullResult,
ReadArgs, RecallArgs, ReindexStats, RememberArgs, Scope, ScopeFilter, SyncArgs,
},
};
#[derive(Clone)]
pub struct MemoryServer {
state: Arc<AppState>,
#[allow(dead_code)]
tool_router: ToolRouter<Self>,
}
const MAX_CONTENT_SIZE: usize = 1_048_576;
async fn incremental_reindex(
repo: &Arc<MemoryRepo>,
embedding: &dyn EmbeddingBackend,
index: &ScopedIndex,
changes: &ChangedMemories,
) -> ReindexStats {
let mut stats = ReindexStats::default();
for name in &changes.removed {
match parse_qualified_name(name) {
Ok((scope, _)) => {
if let Err(e) = index.remove(&scope, name) {
warn!(
qualified_name = %name,
error = %e,
"incremental_reindex: failed to remove vector; skipping"
);
stats.errors += 1;
} else {
stats.removed += 1;
}
}
Err(e) => {
warn!(
qualified_name = %name,
error = %e,
"incremental_reindex: cannot parse qualified name for removal; skipping"
);
}
}
}
let mut pairs: Vec<(Scope, String, String)> = Vec::new(); for qualified in &changes.upserted {
match parse_qualified_name(qualified) {
Ok((scope, name)) => pairs.push((scope, name, qualified.clone())),
Err(e) => {
warn!(
qualified_name = %qualified,
error = %e,
"incremental_reindex: cannot parse qualified name; skipping"
);
stats.errors += 1;
}
}
}
let mut to_embed: Vec<(Scope, String, String)> = Vec::new();
for (scope, name, qualified) in &pairs {
let memory = match repo.read_memory(name, scope).await {
Ok(m) => m,
Err(e) => {
warn!(
qualified_name = %qualified,
error = %e,
"incremental_reindex: failed to read memory; skipping"
);
stats.errors += 1;
continue;
}
};
to_embed.push((scope.clone(), qualified.clone(), memory.content));
}
if to_embed.is_empty() {
return stats;
}
let contents: Vec<String> = to_embed.iter().map(|(_, _, c)| c.clone()).collect();
let vectors = match embedding.embed(&contents).await {
Ok(v) => v,
Err(batch_err) => {
warn!(error = %batch_err, "incremental_reindex: batch embed failed; falling back to per-item");
let mut vecs: Vec<Vec<f32>> = Vec::with_capacity(contents.len());
let mut failed: Vec<usize> = Vec::new();
for (i, content) in contents.iter().enumerate() {
match embedding.embed(std::slice::from_ref(content)).await {
Ok(mut v) => vecs.push(v.remove(0)),
Err(e) => {
warn!(
error = %e,
qualified_name = %to_embed[i].1,
"incremental_reindex: per-item embed failed; skipping"
);
failed.push(i);
stats.errors += 1;
}
}
}
for &i in failed.iter().rev() {
to_embed.remove(i);
}
vecs
}
};
for ((scope, qualified_name, _), vector) in to_embed.iter().zip(vectors.iter()) {
let is_update = index.find_key_by_name(qualified_name).is_some();
match index.add(scope, vector, qualified_name.clone()) {
Ok(_) => {}
Err(e) => {
warn!(
qualified_name = %qualified_name,
error = %e,
"incremental_reindex: add failed; skipping"
);
stats.errors += 1;
continue;
}
}
if is_update {
stats.updated += 1;
} else {
stats.added += 1;
}
}
stats
}
#[tool_router]
impl MemoryServer {
pub fn new(state: Arc<AppState>) -> Self {
Self {
state,
tool_router: Self::tool_router(),
}
}
#[tool(
name = "remember",
description = "Store a new memory. Saves the content to the git-backed repository and \
indexes it for semantic search. Use scope 'project:<basename-of-your-cwd>' for \
project-specific memories or omit for global. Returns the assigned memory ID. \
IMPORTANT: Never store credentials, API keys, tokens, passwords, or other secrets — \
memories are plaintext files in a git repo and may be synced to a remote."
)]
async fn remember(
&self,
Parameters(args): Parameters<RememberArgs>,
) -> Result<String, ErrorData> {
validate_name(&args.name).map_err(ErrorData::from)?;
if args.content.len() > MAX_CONTENT_SIZE {
return Err(ErrorData::from(crate::error::MemoryError::InvalidInput {
reason: format!(
"content size {} exceeds maximum of {} bytes",
args.content.len(),
MAX_CONTENT_SIZE
),
}));
}
let span = tracing::info_span!(
"remember",
name = %args.name,
scope = ?args.scope,
);
let state = Arc::clone(&self.state);
async move {
let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
let metadata = MemoryMetadata::new(scope.clone(), args.tags, args.source);
let memory = Memory::new(args.name, args.content, metadata);
let start = Instant::now();
let vector = state
.embedding
.embed_one(&memory.content)
.await
.map_err(ErrorData::from)?;
info!(embed_ms = start.elapsed().as_millis(), "embedded");
let qualified_name = format!("{}/{}", memory.metadata.scope.dir_prefix(), memory.name);
state
.index
.add(&scope, &vector, qualified_name)
.map_err(ErrorData::from)?;
let start = Instant::now();
state
.repo
.save_memory(&memory)
.await
.map_err(ErrorData::from)?;
info!(repo_ms = start.elapsed().as_millis(), "saved to repo");
Ok(serde_json::json!({
"id": memory.id,
"name": memory.name,
"scope": memory.metadata.scope.to_string(),
})
.to_string())
}
.instrument(span)
.await
}
#[tool(
name = "recall",
description = "Search memories by semantic similarity. Embeds the query and returns the top matching memories as a JSON array \
with name, scope, tags, and a content snippet (max 500 chars).\n\n\
Each result includes `truncated` (bool) and `content_length` (total character count). \
When `truncated` is true, the snippet is incomplete — use the `read` tool with the memory's name and scope \
to retrieve the full content before acting on it.\n\n\
Scope: pass 'project:<basename-of-your-cwd>' to search your current project + global memories, \
'global' for global-only, or 'all' to search everything. Omitting scope defaults to global-only."
)]
async fn recall(&self, Parameters(args): Parameters<RecallArgs>) -> Result<String, ErrorData> {
let span = tracing::info_span!(
"recall",
query = %args.query,
scope = ?args.scope,
limit = ?args.limit,
);
let state = Arc::clone(&self.state);
async move {
let scope_filter =
parse_scope_filter(args.scope.as_deref()).map_err(ErrorData::from)?;
let limit = args.limit.unwrap_or(5).min(100);
let start = Instant::now();
let query_vector = state
.embedding
.embed_one(&args.query)
.await
.map_err(ErrorData::from)?;
info!(embed_ms = start.elapsed().as_millis(), "query embedded");
let start = Instant::now();
let results = state
.index
.search(&scope_filter, &query_vector, limit)
.map_err(ErrorData::from)?;
info!(
search_ms = start.elapsed().as_millis(),
candidates = results.len(),
"index searched"
);
let pre_filter_count = results.len();
let mut results_vec = Vec::new();
let mut skipped_errors: usize = 0;
for (_key, qualified_name, distance) in results {
if results_vec.len() >= limit {
break;
}
let (scope, name) = match parse_qualified_name(&qualified_name) {
Ok(pair) => pair,
Err(e) => {
warn!(
qualified_name = %qualified_name,
error = %e,
"could not parse qualified name from index; skipping"
);
skipped_errors += 1;
continue;
}
};
let memory = match state.repo.read_memory(&name, &scope).await {
Ok(m) => m,
Err(e) => {
warn!(
name = %name,
error = %e,
"could not read memory from repo (deleted?); skipping"
);
skipped_errors += 1;
continue;
}
};
let (snippet, content_length, truncated) = build_snippet(&memory.content);
results_vec.push(serde_json::json!({
"id": memory.id,
"name": memory.name,
"scope": memory.metadata.scope.to_string(),
"tags": memory.metadata.tags,
"content": snippet,
"content_length": content_length,
"truncated": truncated,
"distance": distance,
}));
}
info!(
returned = results_vec.len(),
skipped_errors, "recall complete"
);
Ok(serde_json::json!({
"results": results_vec,
"count": results_vec.len(),
"limit": limit,
"pre_filter_count": pre_filter_count,
"skipped_errors": skipped_errors,
})
.to_string())
}
.instrument(span)
.await
}
#[tool(
name = "forget",
description = "Delete a memory by name. Use scope 'project:<basename-of-your-cwd>' for project-scoped \
memories or omit for global. Removes the file from git and the vector from the search index. \
Returns 'ok' on success."
)]
async fn forget(&self, Parameters(args): Parameters<ForgetArgs>) -> Result<String, ErrorData> {
validate_name(&args.name).map_err(ErrorData::from)?;
let span = tracing::info_span!(
"forget",
name = %args.name,
scope = ?args.scope,
);
let state = Arc::clone(&self.state);
async move {
let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
let start = Instant::now();
state
.repo
.delete_memory(&args.name, &scope)
.await
.map_err(ErrorData::from)?;
let qualified_name = format!("{}/{}", scope.dir_prefix(), args.name);
if let Err(e) = state.index.remove(&scope, &qualified_name) {
warn!(name = %args.name, error = %e, "vector removal failed during forget; stale entry will be skipped at recall");
}
info!(
ms = start.elapsed().as_millis(),
name = %args.name,
"memory forgotten"
);
Ok("ok".to_string())
}
.instrument(span)
.await
}
#[tool(
name = "edit",
description = "Edit an existing memory. Supports partial updates — omit content or \
tags to preserve existing values. Re-embeds and re-indexes the memory. Use scope \
'project:<basename-of-your-cwd>' for project-scoped memories. Returns the memory ID. \
IMPORTANT: Never store credentials, API keys, tokens, passwords, or other secrets — \
memories are plaintext files in a git repo and may be synced to a remote."
)]
async fn edit(&self, Parameters(args): Parameters<EditArgs>) -> Result<String, ErrorData> {
validate_name(&args.name).map_err(ErrorData::from)?;
if let Some(ref content) = args.content {
if content.len() > MAX_CONTENT_SIZE {
return Err(ErrorData::from(crate::error::MemoryError::InvalidInput {
reason: format!(
"content size {} exceeds maximum of {} bytes",
content.len(),
MAX_CONTENT_SIZE
),
}));
}
}
let span = tracing::info_span!(
"edit",
name = %args.name,
scope = ?args.scope,
);
let state = Arc::clone(&self.state);
async move {
let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
let start = Instant::now();
let content_changed = args.content.is_some();
let mut memory = state
.repo
.read_memory(&args.name, &scope)
.await
.map_err(ErrorData::from)?;
if let Some(content) = args.content {
memory.content = content;
}
if let Some(tags) = args.tags {
memory.metadata.tags = tags;
}
memory.metadata.updated_at = Utc::now();
if content_changed {
let qualified_name =
format!("{}/{}", memory.metadata.scope.dir_prefix(), memory.name);
let vector = state
.embedding
.embed_one(&memory.content)
.await
.map_err(ErrorData::from)?;
state
.index
.add(&scope, &vector, qualified_name)
.map_err(ErrorData::from)?;
}
state
.repo
.save_memory(&memory)
.await
.map_err(ErrorData::from)?;
info!(
ms = start.elapsed().as_millis(),
name = %args.name,
content_changed,
"memory edited"
);
Ok(serde_json::json!({
"id": memory.id,
"name": memory.name,
"scope": memory.metadata.scope.to_string(),
})
.to_string())
}
.instrument(span)
.await
}
#[tool(
name = "list",
description = "List stored memories. Pass 'project:<basename-of-your-cwd>' for project + global memories, \
'global' for global-only, or 'all' for everything. Omitting scope defaults to global-only. \
Returns a JSON array of memory summaries without full content."
)]
async fn list(&self, Parameters(args): Parameters<ListArgs>) -> Result<String, ErrorData> {
let span = tracing::info_span!("list", scope = ?args.scope);
let state = Arc::clone(&self.state);
async move {
let scope_filter =
parse_scope_filter(args.scope.as_deref()).map_err(ErrorData::from)?;
let start = Instant::now();
let memories = match &scope_filter {
ScopeFilter::GlobalOnly => state
.repo
.list_memories(Some(&Scope::Global))
.await
.map_err(ErrorData::from)?,
ScopeFilter::All => state
.repo
.list_memories(None)
.await
.map_err(ErrorData::from)?,
ScopeFilter::ProjectAndGlobal(project_name) => {
let project_scope = Scope::Project(project_name.clone());
let mut global = state
.repo
.list_memories(Some(&Scope::Global))
.await
.map_err(ErrorData::from)?;
let project = state
.repo
.list_memories(Some(&project_scope))
.await
.map_err(ErrorData::from)?;
global.extend(project);
global
}
};
info!(
ms = start.elapsed().as_millis(),
count = memories.len(),
"listed memories"
);
let summaries: Vec<serde_json::Value> = memories
.into_iter()
.map(|m| {
serde_json::json!({
"id": m.id,
"name": m.name,
"scope": m.metadata.scope.to_string(),
"tags": m.metadata.tags,
"created_at": m.metadata.created_at,
"updated_at": m.metadata.updated_at,
})
})
.collect();
Ok(serde_json::json!({
"memories": summaries,
"count": summaries.len(),
})
.to_string())
}
.instrument(span)
.await
}
#[tool(
name = "read",
description = "Read a specific memory by name. Use scope 'project:<basename-of-your-cwd>' for \
project-scoped memories or omit for global. Returns the full markdown content and metadata \
(id, scope, tags, timestamps) as a JSON object."
)]
async fn read(&self, Parameters(args): Parameters<ReadArgs>) -> Result<String, ErrorData> {
validate_name(&args.name).map_err(ErrorData::from)?;
let span = tracing::info_span!("read", name = %args.name, scope = ?args.scope);
let state = Arc::clone(&self.state);
async move {
let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
let start = Instant::now();
let memory = state
.repo
.read_memory(&args.name, &scope)
.await
.map_err(ErrorData::from)?;
info!(
ms = start.elapsed().as_millis(),
name = %args.name,
"read memory"
);
Ok(serde_json::json!({
"id": memory.id,
"name": memory.name,
"scope": memory.metadata.scope.to_string(),
"tags": memory.metadata.tags,
"content": memory.content,
"source": memory.metadata.source,
"created_at": memory.metadata.created_at,
"updated_at": memory.metadata.updated_at,
})
.to_string())
}
.instrument(span)
.await
}
#[tool(
name = "sync",
description = "Sync the memory repo with the git remote (push/pull). Requires \
MEMORY_MCP_GITHUB_TOKEN or a token file. Returns a status message."
)]
async fn sync(&self, Parameters(args): Parameters<SyncArgs>) -> Result<String, ErrorData> {
let pull_first = args.pull_first.unwrap_or(true);
let span = tracing::info_span!("sync", pull_first);
let state = Arc::clone(&self.state);
async move {
let start = Instant::now();
let branch = &state.branch;
let mut has_remote = true;
let mut reindex_stats: Option<ReindexStats> = None;
let pull_status = if pull_first {
let result = state
.repo
.pull(&state.auth, branch)
.await
.map_err(ErrorData::from)?;
let mut oid_range: Option<([u8; 20], [u8; 20])> = None;
let status = match result {
PullResult::NoRemote => {
has_remote = false;
"no-remote".to_string()
}
PullResult::UpToDate => "up-to-date".to_string(),
PullResult::FastForward { old_head, new_head } => {
oid_range = Some((old_head, new_head));
"fast-forward".to_string()
}
PullResult::Merged {
conflicts_resolved,
old_head,
new_head,
} => {
oid_range = Some((old_head, new_head));
format!("merged ({} conflicts resolved)", conflicts_resolved)
}
};
if let Some((old_head, new_head)) = oid_range {
let repo = Arc::clone(&state.repo);
let changes = tokio::task::spawn_blocking(move || {
repo.diff_changed_memories(old_head, new_head)
})
.await
.map_err(|e| MemoryError::Join(e.to_string()))
.map_err(ErrorData::from)?
.map_err(ErrorData::from)?;
if !changes.is_empty() {
let stats = incremental_reindex(
&state.repo,
state.embedding.as_ref(),
&state.index,
&changes,
)
.await;
info!(
added = stats.added,
updated = stats.updated,
removed = stats.removed,
errors = stats.errors,
"incremental reindex complete"
);
reindex_stats = Some(stats);
}
}
status
} else {
"skipped".to_string()
};
if has_remote {
state
.repo
.push(&state.auth, branch)
.await
.map_err(ErrorData::from)?;
}
info!(
ms = start.elapsed().as_millis(),
pull_first,
pull_status = %pull_status,
"sync complete"
);
let mut response = serde_json::json!({
"status": "sync complete",
"pull": pull_status,
"branch": branch,
});
if let Some(stats) = reindex_stats {
response["reindex"] = serde_json::json!({
"added": stats.added,
"updated": stats.updated,
"removed": stats.removed,
"errors": stats.errors,
});
}
Ok(response.to_string())
}
.instrument(span)
.await
}
}
#[tool_handler]
impl ServerHandler for MemoryServer {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().enable_tools().build()).with_instructions(
"A semantic memory system for AI coding agents. Memories are stored as markdown files \
in a git repository and indexed for semantic retrieval. Use `remember` to store, `recall` \
to search, `read` to fetch a specific memory, `edit` to update, `forget` to delete, \
`list` to browse, and `sync` to push/pull the remote.\n\n\
Scope convention: always pass scope='project:<basename-of-your-cwd>' when working within \
a project. This returns project memories alongside global ones. Omitting scope defaults to \
global-only for queries (recall, list) and targets a single memory for point operations \
(remember, edit, read, forget). Use scope='all' to search across all projects.\n\n\
IMPORTANT: Never store credentials, API keys, tokens, passwords, or other secrets in \
memory content. Memories are stored as plaintext markdown files committed to a git \
repository and may be synced to a remote. Treat all memory content as public."
.to_string(),
)
}
}
fn build_snippet(content: &str) -> (String, usize, bool) {
let content_length = content.chars().count();
let truncated = content_length > SNIPPET_MAX_CHARS;
let snippet: String = content.chars().take(SNIPPET_MAX_CHARS).collect();
(snippet, content_length, truncated)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn snippet_short_content_not_truncated() {
let content = "Hello, world!";
let (snippet, content_length, truncated) = build_snippet(content);
assert_eq!(snippet, "Hello, world!");
assert_eq!(content_length, 13);
assert!(!truncated);
}
#[test]
fn snippet_exact_limit_not_truncated() {
let content: String = "a".repeat(SNIPPET_MAX_CHARS);
let (snippet, content_length, truncated) = build_snippet(&content);
assert_eq!(snippet, content);
assert_eq!(content_length, SNIPPET_MAX_CHARS);
assert!(!truncated);
}
#[test]
fn snippet_over_limit_is_truncated() {
let content: String = "b".repeat(SNIPPET_MAX_CHARS + 100);
let (snippet, content_length, truncated) = build_snippet(&content);
assert_eq!(snippet.chars().count(), SNIPPET_MAX_CHARS);
assert_eq!(content_length, SNIPPET_MAX_CHARS + 100);
assert!(truncated);
}
#[test]
fn snippet_counts_unicode_chars_not_bytes() {
let emoji_content: String = "\u{1F600}".repeat(SNIPPET_MAX_CHARS + 1);
let (snippet, content_length, truncated) = build_snippet(&emoji_content);
assert_eq!(snippet.chars().count(), SNIPPET_MAX_CHARS);
assert_eq!(content_length, SNIPPET_MAX_CHARS + 1);
assert!(truncated);
}
#[test]
fn snippet_empty_content() {
let (snippet, content_length, truncated) = build_snippet("");
assert_eq!(snippet, "");
assert_eq!(content_length, 0);
assert!(!truncated);
}
}