1use std::{sync::Arc, time::Instant};
2
3use chrono::Utc;
4use rmcp::{
5 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
6 model::{ErrorData, ServerCapabilities, ServerInfo},
7 tool, tool_handler, tool_router, ServerHandler,
8};
9use tracing::{info, warn, Instrument};
10
11use crate::{
12 embedding::EmbeddingBackend,
13 error::MemoryError,
14 index::ScopedIndex,
15 repo::MemoryRepo,
16 types::{
17 parse_qualified_name, parse_scope, parse_scope_filter, validate_name, AppState,
18 ChangedMemories, EditArgs, ForgetArgs, ListArgs, Memory, MemoryMetadata, PullResult,
19 ReadArgs, RecallArgs, ReindexStats, RememberArgs, Scope, ScopeFilter, SyncArgs,
20 },
21};
22
/// MCP server exposing the memory tools (`remember`, `recall`, `forget`, `edit`,
/// `list`, `read`, `sync`) over a git-backed memory repository and a vector index.
#[derive(Clone)]
pub struct MemoryServer {
    /// Shared application state; the tool methods read `repo`, `embedding`,
    /// `index`, `auth` and `branch` from it.
    state: Arc<AppState>,
    /// Router produced by the `#[tool_router]`-generated `Self::tool_router()`;
    /// presumably consumed by the `#[tool_handler]` impl to dispatch tool calls.
    tool_router: ToolRouter<Self>,
}
32
/// Largest memory body, in bytes (`str::len`), that `remember` and `edit` accept: 1 MiB.
const MAX_CONTENT_SIZE: usize = 1024 * 1024;
35
36async fn incremental_reindex(
45 repo: &Arc<MemoryRepo>,
46 embedding: &dyn EmbeddingBackend,
47 index: &ScopedIndex,
48 changes: &ChangedMemories,
49) -> ReindexStats {
50 let mut stats = ReindexStats::default();
51
52 for name in &changes.removed {
54 match parse_qualified_name(name) {
55 Ok((scope, _)) => {
56 if let Err(e) = index.remove(&scope, name) {
57 warn!(
58 qualified_name = %name,
59 error = %e,
60 "incremental_reindex: failed to remove vector; skipping"
61 );
62 stats.errors += 1;
63 } else {
64 stats.removed += 1;
65 }
66 }
67 Err(e) => {
68 warn!(
69 qualified_name = %name,
70 error = %e,
71 "incremental_reindex: cannot parse qualified name for removal; skipping"
72 );
73 }
75 }
76 }
78
79 let mut pairs: Vec<(Scope, String, String)> = Vec::new(); for qualified in &changes.upserted {
84 match parse_qualified_name(qualified) {
85 Ok((scope, name)) => pairs.push((scope, name, qualified.clone())),
86 Err(e) => {
87 warn!(
88 qualified_name = %qualified,
89 error = %e,
90 "incremental_reindex: cannot parse qualified name; skipping"
91 );
92 stats.errors += 1;
93 }
94 }
95 }
96
97 let mut to_embed: Vec<(Scope, String, String)> = Vec::new();
100 for (scope, name, qualified) in &pairs {
101 let memory = match repo.read_memory(name, scope).await {
102 Ok(m) => m,
103 Err(e) => {
104 warn!(
105 qualified_name = %qualified,
106 error = %e,
107 "incremental_reindex: failed to read memory; skipping"
108 );
109 stats.errors += 1;
110 continue;
111 }
112 };
113 to_embed.push((scope.clone(), qualified.clone(), memory.content));
114 }
115
116 if to_embed.is_empty() {
117 return stats;
118 }
119
120 let contents: Vec<String> = to_embed.iter().map(|(_, _, c)| c.clone()).collect();
122 let vectors = match embedding.embed(&contents).await {
123 Ok(v) => v,
124 Err(batch_err) => {
125 warn!(error = %batch_err, "incremental_reindex: batch embed failed; falling back to per-item");
126 let mut vecs: Vec<Vec<f32>> = Vec::with_capacity(contents.len());
127 let mut failed: Vec<usize> = Vec::new();
128 for (i, content) in contents.iter().enumerate() {
129 match embedding.embed(std::slice::from_ref(content)).await {
130 Ok(mut v) => vecs.push(v.remove(0)),
131 Err(e) => {
132 warn!(
133 error = %e,
134 qualified_name = %to_embed[i].1,
135 "incremental_reindex: per-item embed failed; skipping"
136 );
137 failed.push(i);
138 stats.errors += 1;
139 }
140 }
141 }
142 for &i in failed.iter().rev() {
144 to_embed.remove(i);
145 }
146 vecs
147 }
148 };
149
150 for ((scope, qualified_name, _), vector) in to_embed.iter().zip(vectors.iter()) {
152 let is_update = index.find_key_by_name(qualified_name).is_some();
153
154 match index.add(scope, vector, qualified_name.clone()) {
155 Ok(_) => {}
156 Err(e) => {
157 warn!(
158 qualified_name = %qualified_name,
159 error = %e,
160 "incremental_reindex: add failed; skipping"
161 );
162 stats.errors += 1;
163 continue;
164 }
165 }
166
167 if is_update {
168 stats.updated += 1;
169 } else {
170 stats.added += 1;
171 }
172 }
173
174 stats
175}
176
177#[tool_router]
178impl MemoryServer {
179 pub fn new(state: Arc<AppState>) -> Self {
181 Self {
182 state,
183 tool_router: Self::tool_router(),
184 }
185 }
186
187 #[tool(
194 name = "remember",
195 description = "Store a new memory. Saves the content to the git-backed repository and \
196 indexes it for semantic search. Use scope 'project:<basename-of-your-cwd>' for \
197 project-specific memories or omit for global. Returns the assigned memory ID."
198 )]
199 async fn remember(
200 &self,
201 Parameters(args): Parameters<RememberArgs>,
202 ) -> Result<String, ErrorData> {
203 validate_name(&args.name).map_err(ErrorData::from)?;
204 if args.content.len() > MAX_CONTENT_SIZE {
205 return Err(ErrorData::from(crate::error::MemoryError::InvalidInput {
206 reason: format!(
207 "content size {} exceeds maximum of {} bytes",
208 args.content.len(),
209 MAX_CONTENT_SIZE
210 ),
211 }));
212 }
213 let span = tracing::info_span!(
214 "remember",
215 name = %args.name,
216 scope = ?args.scope,
217 );
218 let state = Arc::clone(&self.state);
219 async move {
220 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
221 let metadata = MemoryMetadata::new(scope.clone(), args.tags, args.source);
222 let memory = Memory::new(args.name, args.content, metadata);
223
224 let start = Instant::now();
228 let vector = state
229 .embedding
230 .embed_one(&memory.content)
231 .await
232 .map_err(ErrorData::from)?;
233 info!(embed_ms = start.elapsed().as_millis(), "embedded");
234
235 let qualified_name = format!("{}/{}", memory.metadata.scope.dir_prefix(), memory.name);
236
237 state
238 .index
239 .add(&scope, &vector, qualified_name)
240 .map_err(ErrorData::from)?;
241
242 let start = Instant::now();
243 state
244 .repo
245 .save_memory(&memory)
246 .await
247 .map_err(ErrorData::from)?;
248 info!(repo_ms = start.elapsed().as_millis(), "saved to repo");
249
250 Ok(serde_json::json!({
251 "id": memory.id,
252 "name": memory.name,
253 "scope": memory.metadata.scope.to_string(),
254 })
255 .to_string())
256 }
257 .instrument(span)
258 .await
259 }
260
261 #[tool(
268 name = "recall",
269 description = "Search memories by semantic similarity. Returns the top matching memories as a JSON array \
270 with name, scope, tags, and content snippet.\n\n\
271 Scope: pass 'project:<basename-of-your-cwd>' to search your current project + global memories, \
272 'global' for global-only, or 'all' to search everything. Omitting scope defaults to global-only."
273 )]
274 async fn recall(&self, Parameters(args): Parameters<RecallArgs>) -> Result<String, ErrorData> {
275 let span = tracing::info_span!(
276 "recall",
277 query = %args.query,
278 scope = ?args.scope,
279 limit = ?args.limit,
280 );
281 let state = Arc::clone(&self.state);
282 async move {
283 let scope_filter =
285 parse_scope_filter(args.scope.as_deref()).map_err(ErrorData::from)?;
286
287 let limit = args.limit.unwrap_or(5).min(100);
288
289 let start = Instant::now();
290 let query_vector = state
291 .embedding
292 .embed_one(&args.query)
293 .await
294 .map_err(ErrorData::from)?;
295 info!(embed_ms = start.elapsed().as_millis(), "query embedded");
296
297 let start = Instant::now();
298 let results = state
299 .index
300 .search(&scope_filter, &query_vector, limit)
301 .map_err(ErrorData::from)?;
302 info!(
303 search_ms = start.elapsed().as_millis(),
304 candidates = results.len(),
305 "index searched"
306 );
307
308 let pre_filter_count = results.len();
309 let mut results_vec = Vec::new();
310 let mut skipped_errors: usize = 0;
311
312 for (_key, qualified_name, distance) in results {
313 if results_vec.len() >= limit {
316 break;
317 }
318 let (scope, name) = match parse_qualified_name(&qualified_name) {
319 Ok(pair) => pair,
320 Err(e) => {
321 warn!(
322 qualified_name = %qualified_name,
323 error = %e,
324 "could not parse qualified name from index; skipping"
325 );
326 skipped_errors += 1;
327 continue;
328 }
329 };
330
331 let memory = match state.repo.read_memory(&name, &scope).await {
333 Ok(m) => m,
334 Err(e) => {
335 warn!(
336 name = %name,
337 error = %e,
338 "could not read memory from repo (deleted?); skipping"
339 );
340 skipped_errors += 1;
341 continue;
342 }
343 };
344
345 let snippet: String = memory.content.chars().take(500).collect();
347
348 results_vec.push(serde_json::json!({
349 "id": memory.id,
350 "name": memory.name,
351 "scope": memory.metadata.scope.to_string(),
352 "tags": memory.metadata.tags,
353 "content": snippet,
354 "distance": distance,
355 }));
356 }
357
358 info!(
359 returned = results_vec.len(),
360 skipped_errors, "recall complete"
361 );
362
363 Ok(serde_json::json!({
364 "results": results_vec,
365 "count": results_vec.len(),
366 "limit": limit,
367 "pre_filter_count": pre_filter_count,
368 "skipped_errors": skipped_errors,
369 })
370 .to_string())
371 }
372 .instrument(span)
373 .await
374 }
375
376 #[tool(
383 name = "forget",
384 description = "Delete a memory by name. Use scope 'project:<basename-of-your-cwd>' for project-scoped \
385 memories or omit for global. Removes the file from git and the vector from the search index. \
386 Returns 'ok' on success."
387 )]
388 async fn forget(&self, Parameters(args): Parameters<ForgetArgs>) -> Result<String, ErrorData> {
389 validate_name(&args.name).map_err(ErrorData::from)?;
390 let span = tracing::info_span!(
391 "forget",
392 name = %args.name,
393 scope = ?args.scope,
394 );
395 let state = Arc::clone(&self.state);
396 async move {
397 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
398
399 let start = Instant::now();
400
401 state
403 .repo
404 .delete_memory(&args.name, &scope)
405 .await
406 .map_err(ErrorData::from)?;
407
408 let qualified_name = format!("{}/{}", scope.dir_prefix(), args.name);
410 if let Err(e) = state.index.remove(&scope, &qualified_name) {
411 warn!(name = %args.name, error = %e, "vector removal failed during forget; stale entry will be skipped at recall");
412 }
413
414 info!(
415 ms = start.elapsed().as_millis(),
416 name = %args.name,
417 "memory forgotten"
418 );
419
420 Ok("ok".to_string())
421 }
422 .instrument(span)
423 .await
424 }
425
426 #[tool(
435 name = "edit",
436 description = "Edit an existing memory. Supports partial updates — omit content or \
437 tags to preserve existing values. Re-embeds and re-indexes the memory. Use scope \
438 'project:<basename-of-your-cwd>' for project-scoped memories. Returns the memory ID."
439 )]
440 async fn edit(&self, Parameters(args): Parameters<EditArgs>) -> Result<String, ErrorData> {
441 validate_name(&args.name).map_err(ErrorData::from)?;
442 if let Some(ref content) = args.content {
443 if content.len() > MAX_CONTENT_SIZE {
444 return Err(ErrorData::from(crate::error::MemoryError::InvalidInput {
445 reason: format!(
446 "content size {} exceeds maximum of {} bytes",
447 content.len(),
448 MAX_CONTENT_SIZE
449 ),
450 }));
451 }
452 }
453 let span = tracing::info_span!(
454 "edit",
455 name = %args.name,
456 scope = ?args.scope,
457 );
458 let state = Arc::clone(&self.state);
459 async move {
460 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
461
462 let start = Instant::now();
463
464 let content_changed = args.content.is_some();
466
467 let mut memory = state
469 .repo
470 .read_memory(&args.name, &scope)
471 .await
472 .map_err(ErrorData::from)?;
473
474 if let Some(content) = args.content {
476 memory.content = content;
477 }
478 if let Some(tags) = args.tags {
479 memory.metadata.tags = tags;
480 }
481 memory.metadata.updated_at = Utc::now();
482
483 if content_changed {
486 let qualified_name =
487 format!("{}/{}", memory.metadata.scope.dir_prefix(), memory.name);
488
489 let vector = state
491 .embedding
492 .embed_one(&memory.content)
493 .await
494 .map_err(ErrorData::from)?;
495
496 state
497 .index
498 .add(&scope, &vector, qualified_name)
499 .map_err(ErrorData::from)?;
500 }
501
502 state
504 .repo
505 .save_memory(&memory)
506 .await
507 .map_err(ErrorData::from)?;
508
509 info!(
510 ms = start.elapsed().as_millis(),
511 name = %args.name,
512 content_changed,
513 "memory edited"
514 );
515
516 Ok(serde_json::json!({
517 "id": memory.id,
518 "name": memory.name,
519 "scope": memory.metadata.scope.to_string(),
520 })
521 .to_string())
522 }
523 .instrument(span)
524 .await
525 }
526
527 #[tool(
532 name = "list",
533 description = "List stored memories. Pass 'project:<basename-of-your-cwd>' for project + global memories, \
534 'global' for global-only, or 'all' for everything. Omitting scope defaults to global-only. \
535 Returns a JSON array of memory summaries without full content."
536 )]
537 async fn list(&self, Parameters(args): Parameters<ListArgs>) -> Result<String, ErrorData> {
538 let span = tracing::info_span!("list", scope = ?args.scope);
539 let state = Arc::clone(&self.state);
540 async move {
541 let scope_filter =
542 parse_scope_filter(args.scope.as_deref()).map_err(ErrorData::from)?;
543
544 let start = Instant::now();
545 let memories = match &scope_filter {
546 ScopeFilter::GlobalOnly => state
547 .repo
548 .list_memories(Some(&Scope::Global))
549 .await
550 .map_err(ErrorData::from)?,
551 ScopeFilter::All => state
552 .repo
553 .list_memories(None)
554 .await
555 .map_err(ErrorData::from)?,
556 ScopeFilter::ProjectAndGlobal(project_name) => {
557 let project_scope = Scope::Project(project_name.clone());
558 let mut global = state
559 .repo
560 .list_memories(Some(&Scope::Global))
561 .await
562 .map_err(ErrorData::from)?;
563 let project = state
564 .repo
565 .list_memories(Some(&project_scope))
566 .await
567 .map_err(ErrorData::from)?;
568 global.extend(project);
569 global
570 }
571 };
572 info!(
573 ms = start.elapsed().as_millis(),
574 count = memories.len(),
575 "listed memories"
576 );
577
578 let summaries: Vec<serde_json::Value> = memories
579 .into_iter()
580 .map(|m| {
581 serde_json::json!({
582 "id": m.id,
583 "name": m.name,
584 "scope": m.metadata.scope.to_string(),
585 "tags": m.metadata.tags,
586 "created_at": m.metadata.created_at,
587 "updated_at": m.metadata.updated_at,
588 })
589 })
590 .collect();
591
592 Ok(serde_json::json!({
593 "memories": summaries,
594 "count": summaries.len(),
595 })
596 .to_string())
597 }
598 .instrument(span)
599 .await
600 }
601
602 #[tool(
607 name = "read",
608 description = "Read a specific memory by name. Use scope 'project:<basename-of-your-cwd>' for \
609 project-scoped memories or omit for global. Returns the full markdown content and metadata \
610 (id, scope, tags, timestamps) as a JSON object."
611 )]
612 async fn read(&self, Parameters(args): Parameters<ReadArgs>) -> Result<String, ErrorData> {
613 validate_name(&args.name).map_err(ErrorData::from)?;
614 let span = tracing::info_span!("read", name = %args.name, scope = ?args.scope);
615 let state = Arc::clone(&self.state);
616 async move {
617 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
618
619 let start = Instant::now();
620 let memory = state
621 .repo
622 .read_memory(&args.name, &scope)
623 .await
624 .map_err(ErrorData::from)?;
625 info!(
626 ms = start.elapsed().as_millis(),
627 name = %args.name,
628 "read memory"
629 );
630
631 Ok(serde_json::json!({
632 "id": memory.id,
633 "name": memory.name,
634 "scope": memory.metadata.scope.to_string(),
635 "tags": memory.metadata.tags,
636 "content": memory.content,
637 "source": memory.metadata.source,
638 "created_at": memory.metadata.created_at,
639 "updated_at": memory.metadata.updated_at,
640 })
641 .to_string())
642 }
643 .instrument(span)
644 .await
645 }
646
647 #[tool(
655 name = "sync",
656 description = "Sync the memory repo with the git remote (push/pull). Requires \
657 MEMORY_MCP_GITHUB_TOKEN or a token file. Returns a status message."
658 )]
659 async fn sync(&self, Parameters(args): Parameters<SyncArgs>) -> Result<String, ErrorData> {
660 let pull_first = args.pull_first.unwrap_or(true);
661 let span = tracing::info_span!("sync", pull_first);
662 let state = Arc::clone(&self.state);
663 async move {
664 let start = Instant::now();
665 let branch = &state.branch;
666
667 let mut has_remote = true;
670
671 let mut reindex_stats: Option<ReindexStats> = None;
672
673 let pull_status = if pull_first {
674 let result = state
675 .repo
676 .pull(&state.auth, branch)
677 .await
678 .map_err(ErrorData::from)?;
679
680 let mut oid_range: Option<([u8; 20], [u8; 20])> = None;
681 let status = match result {
682 PullResult::NoRemote => {
683 has_remote = false;
684 "no-remote".to_string()
685 }
686 PullResult::UpToDate => "up-to-date".to_string(),
687 PullResult::FastForward { old_head, new_head } => {
688 oid_range = Some((old_head, new_head));
689 "fast-forward".to_string()
690 }
691 PullResult::Merged {
692 conflicts_resolved,
693 old_head,
694 new_head,
695 } => {
696 oid_range = Some((old_head, new_head));
697 format!("merged ({} conflicts resolved)", conflicts_resolved)
698 }
699 };
700
701 if let Some((old_head, new_head)) = oid_range {
702 let repo = Arc::clone(&state.repo);
703 let changes = tokio::task::spawn_blocking(move || {
704 repo.diff_changed_memories(old_head, new_head)
705 })
706 .await
707 .map_err(|e| MemoryError::Join(e.to_string()))
708 .map_err(ErrorData::from)?
709 .map_err(ErrorData::from)?;
710
711 if !changes.is_empty() {
712 let stats = incremental_reindex(
713 &state.repo,
714 state.embedding.as_ref(),
715 &state.index,
716 &changes,
717 )
718 .await;
719 info!(
720 added = stats.added,
721 updated = stats.updated,
722 removed = stats.removed,
723 errors = stats.errors,
724 "incremental reindex complete"
725 );
726 reindex_stats = Some(stats);
727 }
728 }
729
730 status
731 } else {
732 "skipped".to_string()
733 };
734
735 if has_remote {
736 state
737 .repo
738 .push(&state.auth, branch)
739 .await
740 .map_err(ErrorData::from)?;
741 }
742
743 info!(
744 ms = start.elapsed().as_millis(),
745 pull_first,
746 pull_status = %pull_status,
747 "sync complete"
748 );
749
750 let mut response = serde_json::json!({
751 "status": "sync complete",
752 "pull": pull_status,
753 "branch": branch,
754 });
755
756 if let Some(stats) = reindex_stats {
757 response["reindex"] = serde_json::json!({
758 "added": stats.added,
759 "updated": stats.updated,
760 "removed": stats.removed,
761 "errors": stats.errors,
762 });
763 }
764
765 Ok(response.to_string())
766 }
767 .instrument(span)
768 .await
769 }
770}
771
772#[tool_handler]
773impl ServerHandler for MemoryServer {
774 fn get_info(&self) -> ServerInfo {
775 ServerInfo::new(ServerCapabilities::builder().enable_tools().build()).with_instructions(
776 "A semantic memory system for AI coding agents. Memories are stored as markdown files \
777 in a git repository and indexed for semantic retrieval. Use `remember` to store, `recall` \
778 to search, `read` to fetch a specific memory, `edit` to update, `forget` to delete, \
779 `list` to browse, and `sync` to push/pull the remote.\n\n\
780 Scope convention: always pass scope='project:<basename-of-your-cwd>' when working within \
781 a project. This returns project memories alongside global ones. Omitting scope defaults to \
782 global-only for queries (recall, list) and targets a single memory for point operations \
783 (remember, edit, read, forget). Use scope='all' to search across all projects."
784 .to_string(),
785 )
786 }
787}