1use std::{sync::Arc, time::Instant};
2
/// Maximum number of characters of a memory's content that `build_snippet` keeps in the
/// snippet returned by `recall`.
const SNIPPET_MAX_CHARS: usize = 500;
6
7use chrono::Utc;
8use rmcp::{
9 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
10 model::{ErrorData, ServerCapabilities, ServerInfo},
11 tool, tool_handler, tool_router, ServerHandler,
12};
13use tracing::{info, warn, Instrument};
14
15use crate::{
16 embedding::EmbeddingBackend,
17 error::MemoryError,
18 index::ScopedIndex,
19 repo::MemoryRepo,
20 types::{
21 parse_qualified_name, parse_scope, parse_scope_filter, validate_name, AppState,
22 ChangedMemories, EditArgs, ForgetArgs, ListArgs, Memory, MemoryMetadata, PullResult,
23 ReadArgs, RecallArgs, ReindexStats, RememberArgs, Scope, ScopeFilter, SyncArgs,
24 },
25};
26
27#[derive(Clone)]
32pub struct MemoryServer {
33 state: Arc<AppState>,
34 #[allow(dead_code)]
37 tool_router: ToolRouter<Self>,
38}
39
/// Largest memory content, in bytes, accepted by `remember` and `edit` (1 MiB).
const MAX_CONTENT_SIZE: usize = 1024 * 1024;
42
43async fn incremental_reindex(
52 repo: &Arc<MemoryRepo>,
53 embedding: &dyn EmbeddingBackend,
54 index: &ScopedIndex,
55 changes: &ChangedMemories,
56) -> ReindexStats {
57 let mut stats = ReindexStats::default();
58
59 for name in &changes.removed {
61 match parse_qualified_name(name) {
62 Ok((scope, _)) => {
63 if let Err(e) = index.remove(&scope, name) {
64 warn!(
65 qualified_name = %name,
66 error = %e,
67 "incremental_reindex: failed to remove vector; skipping"
68 );
69 stats.errors += 1;
70 } else {
71 stats.removed += 1;
72 }
73 }
74 Err(e) => {
75 warn!(
76 qualified_name = %name,
77 error = %e,
78 "incremental_reindex: cannot parse qualified name for removal; skipping"
79 );
80 }
82 }
83 }
85
86 let mut pairs: Vec<(Scope, String, String)> = Vec::new(); for qualified in &changes.upserted {
91 match parse_qualified_name(qualified) {
92 Ok((scope, name)) => pairs.push((scope, name, qualified.clone())),
93 Err(e) => {
94 warn!(
95 qualified_name = %qualified,
96 error = %e,
97 "incremental_reindex: cannot parse qualified name; skipping"
98 );
99 stats.errors += 1;
100 }
101 }
102 }
103
104 let mut to_embed: Vec<(Scope, String, String)> = Vec::new();
107 for (scope, name, qualified) in &pairs {
108 let memory = match repo.read_memory(name, scope).await {
109 Ok(m) => m,
110 Err(e) => {
111 warn!(
112 qualified_name = %qualified,
113 error = %e,
114 "incremental_reindex: failed to read memory; skipping"
115 );
116 stats.errors += 1;
117 continue;
118 }
119 };
120 to_embed.push((scope.clone(), qualified.clone(), memory.content));
121 }
122
123 if to_embed.is_empty() {
124 return stats;
125 }
126
127 let contents: Vec<String> = to_embed.iter().map(|(_, _, c)| c.clone()).collect();
129 let vectors = match embedding.embed(&contents).await {
130 Ok(v) => v,
131 Err(batch_err) => {
132 warn!(error = %batch_err, "incremental_reindex: batch embed failed; falling back to per-item");
133 let mut vecs: Vec<Vec<f32>> = Vec::with_capacity(contents.len());
134 let mut failed: Vec<usize> = Vec::new();
135 for (i, content) in contents.iter().enumerate() {
136 match embedding.embed(std::slice::from_ref(content)).await {
137 Ok(mut v) => vecs.push(v.remove(0)),
138 Err(e) => {
139 warn!(
140 error = %e,
141 qualified_name = %to_embed[i].1,
142 "incremental_reindex: per-item embed failed; skipping"
143 );
144 failed.push(i);
145 stats.errors += 1;
146 }
147 }
148 }
149 for &i in failed.iter().rev() {
151 to_embed.remove(i);
152 }
153 vecs
154 }
155 };
156
157 for ((scope, qualified_name, _), vector) in to_embed.iter().zip(vectors.iter()) {
159 let is_update = index.find_key_by_name(qualified_name).is_some();
160
161 match index.add(scope, vector, qualified_name.clone()) {
162 Ok(_) => {}
163 Err(e) => {
164 warn!(
165 qualified_name = %qualified_name,
166 error = %e,
167 "incremental_reindex: add failed; skipping"
168 );
169 stats.errors += 1;
170 continue;
171 }
172 }
173
174 if is_update {
175 stats.updated += 1;
176 } else {
177 stats.added += 1;
178 }
179 }
180
181 stats
182}
183
184#[tool_router]
185impl MemoryServer {
186 pub fn new(state: Arc<AppState>) -> Self {
188 Self {
189 state,
190 tool_router: Self::tool_router(),
191 }
192 }
193
194 #[tool(
201 name = "remember",
202 description = "Store a new memory. Saves the content to the git-backed repository and \
203 indexes it for semantic search. Use scope 'project:<basename-of-your-cwd>' for \
204 project-specific memories or omit for global. Returns the assigned memory ID. \
205 IMPORTANT: Never store credentials, API keys, tokens, passwords, or other secrets — \
206 memories are plaintext files in a git repo and may be synced to a remote."
207 )]
208 async fn remember(
209 &self,
210 Parameters(args): Parameters<RememberArgs>,
211 ) -> Result<String, ErrorData> {
212 validate_name(&args.name).map_err(ErrorData::from)?;
213 if args.content.len() > MAX_CONTENT_SIZE {
214 return Err(ErrorData::from(crate::error::MemoryError::InvalidInput {
215 reason: format!(
216 "content size {} exceeds maximum of {} bytes",
217 args.content.len(),
218 MAX_CONTENT_SIZE
219 ),
220 }));
221 }
222 let span = tracing::info_span!(
223 "remember",
224 name = %args.name,
225 scope = ?args.scope,
226 );
227 let state = Arc::clone(&self.state);
228 async move {
229 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
230 let metadata = MemoryMetadata::new(scope.clone(), args.tags, args.source);
231 let memory = Memory::new(args.name, args.content, metadata);
232
233 let start = Instant::now();
237 let vector = state
238 .embedding
239 .embed_one(&memory.content)
240 .await
241 .map_err(ErrorData::from)?;
242 info!(embed_ms = start.elapsed().as_millis(), "embedded");
243
244 let qualified_name = format!("{}/{}", memory.metadata.scope.dir_prefix(), memory.name);
245
246 state
247 .index
248 .add(&scope, &vector, qualified_name)
249 .map_err(ErrorData::from)?;
250
251 let start = Instant::now();
252 state
253 .repo
254 .save_memory(&memory)
255 .await
256 .map_err(ErrorData::from)?;
257 info!(repo_ms = start.elapsed().as_millis(), "saved to repo");
258
259 Ok(serde_json::json!({
260 "id": memory.id,
261 "name": memory.name,
262 "scope": memory.metadata.scope.to_string(),
263 })
264 .to_string())
265 }
266 .instrument(span)
267 .await
268 }
269
270 #[tool(
277 name = "recall",
278 description = "Search memories by semantic similarity. Embeds the query and returns the top matching memories as a JSON array \
279 with name, scope, tags, and a content snippet (max 500 chars).\n\n\
280 Each result includes `truncated` (bool) and `content_length` (total character count). \
281 When `truncated` is true, the snippet is incomplete — use the `read` tool with the memory's name and scope \
282 to retrieve the full content before acting on it.\n\n\
283 Scope: pass 'project:<basename-of-your-cwd>' to search your current project + global memories, \
284 'global' for global-only, or 'all' to search everything. Omitting scope defaults to global-only."
285 )]
286 async fn recall(&self, Parameters(args): Parameters<RecallArgs>) -> Result<String, ErrorData> {
287 let span = tracing::info_span!(
288 "recall",
289 query = %args.query,
290 scope = ?args.scope,
291 limit = ?args.limit,
292 );
293 let state = Arc::clone(&self.state);
294 async move {
295 let scope_filter =
297 parse_scope_filter(args.scope.as_deref()).map_err(ErrorData::from)?;
298
299 let limit = args.limit.unwrap_or(5).min(100);
300
301 let start = Instant::now();
302 let query_vector = state
303 .embedding
304 .embed_one(&args.query)
305 .await
306 .map_err(ErrorData::from)?;
307 info!(embed_ms = start.elapsed().as_millis(), "query embedded");
308
309 let start = Instant::now();
310 let results = state
311 .index
312 .search(&scope_filter, &query_vector, limit)
313 .map_err(ErrorData::from)?;
314 info!(
315 search_ms = start.elapsed().as_millis(),
316 candidates = results.len(),
317 "index searched"
318 );
319
320 let pre_filter_count = results.len();
321 let mut results_vec = Vec::new();
322 let mut skipped_errors: usize = 0;
323
324 for (_key, qualified_name, distance) in results {
325 if results_vec.len() >= limit {
328 break;
329 }
330 let (scope, name) = match parse_qualified_name(&qualified_name) {
331 Ok(pair) => pair,
332 Err(e) => {
333 warn!(
334 qualified_name = %qualified_name,
335 error = %e,
336 "could not parse qualified name from index; skipping"
337 );
338 skipped_errors += 1;
339 continue;
340 }
341 };
342
343 let memory = match state.repo.read_memory(&name, &scope).await {
345 Ok(m) => m,
346 Err(e) => {
347 warn!(
348 name = %name,
349 error = %e,
350 "could not read memory from repo (deleted?); skipping"
351 );
352 skipped_errors += 1;
353 continue;
354 }
355 };
356
357 let (snippet, content_length, truncated) = build_snippet(&memory.content);
358
359 results_vec.push(serde_json::json!({
360 "id": memory.id,
361 "name": memory.name,
362 "scope": memory.metadata.scope.to_string(),
363 "tags": memory.metadata.tags,
364 "content": snippet,
365 "content_length": content_length,
366 "truncated": truncated,
367 "distance": distance,
368 }));
369 }
370
371 info!(
372 returned = results_vec.len(),
373 skipped_errors, "recall complete"
374 );
375
376 Ok(serde_json::json!({
377 "results": results_vec,
378 "count": results_vec.len(),
379 "limit": limit,
380 "pre_filter_count": pre_filter_count,
381 "skipped_errors": skipped_errors,
382 })
383 .to_string())
384 }
385 .instrument(span)
386 .await
387 }
388
389 #[tool(
396 name = "forget",
397 description = "Delete a memory by name. Use scope 'project:<basename-of-your-cwd>' for project-scoped \
398 memories or omit for global. Removes the file from git and the vector from the search index. \
399 Returns 'ok' on success."
400 )]
401 async fn forget(&self, Parameters(args): Parameters<ForgetArgs>) -> Result<String, ErrorData> {
402 validate_name(&args.name).map_err(ErrorData::from)?;
403 let span = tracing::info_span!(
404 "forget",
405 name = %args.name,
406 scope = ?args.scope,
407 );
408 let state = Arc::clone(&self.state);
409 async move {
410 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
411
412 let start = Instant::now();
413
414 state
416 .repo
417 .delete_memory(&args.name, &scope)
418 .await
419 .map_err(ErrorData::from)?;
420
421 let qualified_name = format!("{}/{}", scope.dir_prefix(), args.name);
423 if let Err(e) = state.index.remove(&scope, &qualified_name) {
424 warn!(name = %args.name, error = %e, "vector removal failed during forget; stale entry will be skipped at recall");
425 }
426
427 info!(
428 ms = start.elapsed().as_millis(),
429 name = %args.name,
430 "memory forgotten"
431 );
432
433 Ok("ok".to_string())
434 }
435 .instrument(span)
436 .await
437 }
438
439 #[tool(
448 name = "edit",
449 description = "Edit an existing memory. Supports partial updates — omit content or \
450 tags to preserve existing values. Re-embeds and re-indexes the memory. Use scope \
451 'project:<basename-of-your-cwd>' for project-scoped memories. Returns the memory ID. \
452 IMPORTANT: Never store credentials, API keys, tokens, passwords, or other secrets — \
453 memories are plaintext files in a git repo and may be synced to a remote."
454 )]
455 async fn edit(&self, Parameters(args): Parameters<EditArgs>) -> Result<String, ErrorData> {
456 validate_name(&args.name).map_err(ErrorData::from)?;
457 if let Some(ref content) = args.content {
458 if content.len() > MAX_CONTENT_SIZE {
459 return Err(ErrorData::from(crate::error::MemoryError::InvalidInput {
460 reason: format!(
461 "content size {} exceeds maximum of {} bytes",
462 content.len(),
463 MAX_CONTENT_SIZE
464 ),
465 }));
466 }
467 }
468 let span = tracing::info_span!(
469 "edit",
470 name = %args.name,
471 scope = ?args.scope,
472 );
473 let state = Arc::clone(&self.state);
474 async move {
475 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
476
477 let start = Instant::now();
478
479 let content_changed = args.content.is_some();
481
482 let mut memory = state
484 .repo
485 .read_memory(&args.name, &scope)
486 .await
487 .map_err(ErrorData::from)?;
488
489 if let Some(content) = args.content {
491 memory.content = content;
492 }
493 if let Some(tags) = args.tags {
494 memory.metadata.tags = tags;
495 }
496 memory.metadata.updated_at = Utc::now();
497
498 if content_changed {
501 let qualified_name =
502 format!("{}/{}", memory.metadata.scope.dir_prefix(), memory.name);
503
504 let vector = state
506 .embedding
507 .embed_one(&memory.content)
508 .await
509 .map_err(ErrorData::from)?;
510
511 state
512 .index
513 .add(&scope, &vector, qualified_name)
514 .map_err(ErrorData::from)?;
515 }
516
517 state
519 .repo
520 .save_memory(&memory)
521 .await
522 .map_err(ErrorData::from)?;
523
524 info!(
525 ms = start.elapsed().as_millis(),
526 name = %args.name,
527 content_changed,
528 "memory edited"
529 );
530
531 Ok(serde_json::json!({
532 "id": memory.id,
533 "name": memory.name,
534 "scope": memory.metadata.scope.to_string(),
535 })
536 .to_string())
537 }
538 .instrument(span)
539 .await
540 }
541
542 #[tool(
547 name = "list",
548 description = "List stored memories. Pass 'project:<basename-of-your-cwd>' for project + global memories, \
549 'global' for global-only, or 'all' for everything. Omitting scope defaults to global-only. \
550 Returns a JSON array of memory summaries without full content."
551 )]
552 async fn list(&self, Parameters(args): Parameters<ListArgs>) -> Result<String, ErrorData> {
553 let span = tracing::info_span!("list", scope = ?args.scope);
554 let state = Arc::clone(&self.state);
555 async move {
556 let scope_filter =
557 parse_scope_filter(args.scope.as_deref()).map_err(ErrorData::from)?;
558
559 let start = Instant::now();
560 let memories = match &scope_filter {
561 ScopeFilter::GlobalOnly => state
562 .repo
563 .list_memories(Some(&Scope::Global))
564 .await
565 .map_err(ErrorData::from)?,
566 ScopeFilter::All => state
567 .repo
568 .list_memories(None)
569 .await
570 .map_err(ErrorData::from)?,
571 ScopeFilter::ProjectAndGlobal(project_name) => {
572 let project_scope = Scope::Project(project_name.clone());
573 let mut global = state
574 .repo
575 .list_memories(Some(&Scope::Global))
576 .await
577 .map_err(ErrorData::from)?;
578 let project = state
579 .repo
580 .list_memories(Some(&project_scope))
581 .await
582 .map_err(ErrorData::from)?;
583 global.extend(project);
584 global
585 }
586 };
587 info!(
588 ms = start.elapsed().as_millis(),
589 count = memories.len(),
590 "listed memories"
591 );
592
593 let summaries: Vec<serde_json::Value> = memories
594 .into_iter()
595 .map(|m| {
596 serde_json::json!({
597 "id": m.id,
598 "name": m.name,
599 "scope": m.metadata.scope.to_string(),
600 "tags": m.metadata.tags,
601 "created_at": m.metadata.created_at,
602 "updated_at": m.metadata.updated_at,
603 })
604 })
605 .collect();
606
607 Ok(serde_json::json!({
608 "memories": summaries,
609 "count": summaries.len(),
610 })
611 .to_string())
612 }
613 .instrument(span)
614 .await
615 }
616
617 #[tool(
622 name = "read",
623 description = "Read a specific memory by name. Use scope 'project:<basename-of-your-cwd>' for \
624 project-scoped memories or omit for global. Returns the full markdown content and metadata \
625 (id, scope, tags, timestamps) as a JSON object."
626 )]
627 async fn read(&self, Parameters(args): Parameters<ReadArgs>) -> Result<String, ErrorData> {
628 validate_name(&args.name).map_err(ErrorData::from)?;
629 let span = tracing::info_span!("read", name = %args.name, scope = ?args.scope);
630 let state = Arc::clone(&self.state);
631 async move {
632 let scope = parse_scope(args.scope.as_deref()).map_err(ErrorData::from)?;
633
634 let start = Instant::now();
635 let memory = state
636 .repo
637 .read_memory(&args.name, &scope)
638 .await
639 .map_err(ErrorData::from)?;
640 info!(
641 ms = start.elapsed().as_millis(),
642 name = %args.name,
643 "read memory"
644 );
645
646 Ok(serde_json::json!({
647 "id": memory.id,
648 "name": memory.name,
649 "scope": memory.metadata.scope.to_string(),
650 "tags": memory.metadata.tags,
651 "content": memory.content,
652 "source": memory.metadata.source,
653 "created_at": memory.metadata.created_at,
654 "updated_at": memory.metadata.updated_at,
655 })
656 .to_string())
657 }
658 .instrument(span)
659 .await
660 }
661
662 #[tool(
670 name = "sync",
671 description = "Sync the memory repo with the git remote (push/pull). Requires \
672 MEMORY_MCP_GITHUB_TOKEN or a token file. Returns a status message."
673 )]
674 async fn sync(&self, Parameters(args): Parameters<SyncArgs>) -> Result<String, ErrorData> {
675 let pull_first = args.pull_first.unwrap_or(true);
676 let span = tracing::info_span!("sync", pull_first);
677 let state = Arc::clone(&self.state);
678 async move {
679 let start = Instant::now();
680 let branch = &state.branch;
681
682 let mut has_remote = true;
685
686 let mut reindex_stats: Option<ReindexStats> = None;
687
688 let pull_status = if pull_first {
689 let result = state
690 .repo
691 .pull(&state.auth, branch)
692 .await
693 .map_err(ErrorData::from)?;
694
695 let mut oid_range: Option<([u8; 20], [u8; 20])> = None;
696 let status = match result {
697 PullResult::NoRemote => {
698 has_remote = false;
699 "no-remote".to_string()
700 }
701 PullResult::UpToDate => "up-to-date".to_string(),
702 PullResult::FastForward { old_head, new_head } => {
703 oid_range = Some((old_head, new_head));
704 "fast-forward".to_string()
705 }
706 PullResult::Merged {
707 conflicts_resolved,
708 old_head,
709 new_head,
710 } => {
711 oid_range = Some((old_head, new_head));
712 format!("merged ({} conflicts resolved)", conflicts_resolved)
713 }
714 };
715
716 if let Some((old_head, new_head)) = oid_range {
717 let repo = Arc::clone(&state.repo);
718 let changes = tokio::task::spawn_blocking(move || {
719 repo.diff_changed_memories(old_head, new_head)
720 })
721 .await
722 .map_err(|e| MemoryError::Join(e.to_string()))
723 .map_err(ErrorData::from)?
724 .map_err(ErrorData::from)?;
725
726 if !changes.is_empty() {
727 let stats = incremental_reindex(
728 &state.repo,
729 state.embedding.as_ref(),
730 &state.index,
731 &changes,
732 )
733 .await;
734 info!(
735 added = stats.added,
736 updated = stats.updated,
737 removed = stats.removed,
738 errors = stats.errors,
739 "incremental reindex complete"
740 );
741 reindex_stats = Some(stats);
742 }
743 }
744
745 status
746 } else {
747 "skipped".to_string()
748 };
749
750 if has_remote {
751 state
752 .repo
753 .push(&state.auth, branch)
754 .await
755 .map_err(ErrorData::from)?;
756 }
757
758 info!(
759 ms = start.elapsed().as_millis(),
760 pull_first,
761 pull_status = %pull_status,
762 "sync complete"
763 );
764
765 let mut response = serde_json::json!({
766 "status": "sync complete",
767 "pull": pull_status,
768 "branch": branch,
769 });
770
771 if let Some(stats) = reindex_stats {
772 response["reindex"] = serde_json::json!({
773 "added": stats.added,
774 "updated": stats.updated,
775 "removed": stats.removed,
776 "errors": stats.errors,
777 });
778 }
779
780 Ok(response.to_string())
781 }
782 .instrument(span)
783 .await
784 }
785}
786
787#[tool_handler]
788impl ServerHandler for MemoryServer {
789 fn get_info(&self) -> ServerInfo {
790 ServerInfo::new(ServerCapabilities::builder().enable_tools().build()).with_instructions(
791 "A semantic memory system for AI coding agents. Memories are stored as markdown files \
792 in a git repository and indexed for semantic retrieval. Use `remember` to store, `recall` \
793 to search, `read` to fetch a specific memory, `edit` to update, `forget` to delete, \
794 `list` to browse, and `sync` to push/pull the remote.\n\n\
795 Scope convention: always pass scope='project:<basename-of-your-cwd>' when working within \
796 a project. This returns project memories alongside global ones. Omitting scope defaults to \
797 global-only for queries (recall, list) and targets a single memory for point operations \
798 (remember, edit, read, forget). Use scope='all' to search across all projects.\n\n\
799 IMPORTANT: Never store credentials, API keys, tokens, passwords, or other secrets in \
800 memory content. Memories are stored as plaintext markdown files committed to a git \
801 repository and may be synced to a remote. Treat all memory content as public."
802 .to_string(),
803 )
804 }
805}
806
807fn build_snippet(content: &str) -> (String, usize, bool) {
809 let content_length = content.chars().count();
810 let truncated = content_length > SNIPPET_MAX_CHARS;
811 let snippet: String = content.chars().take(SNIPPET_MAX_CHARS).collect();
812 (snippet, content_length, truncated)
813}
814
#[cfg(test)]
mod tests {
    use super::*;

    /// Content shorter than the limit comes back unchanged.
    #[test]
    fn snippet_short_content_not_truncated() {
        let (text, total, was_cut) = build_snippet("Hello, world!");
        assert_eq!(text, "Hello, world!");
        assert_eq!(total, 13);
        assert!(!was_cut);
    }

    /// Content of exactly `SNIPPET_MAX_CHARS` characters is not reported as truncated.
    #[test]
    fn snippet_exact_limit_not_truncated() {
        let input = "a".repeat(SNIPPET_MAX_CHARS);
        let (text, total, was_cut) = build_snippet(&input);
        assert_eq!(text, input);
        assert_eq!(total, SNIPPET_MAX_CHARS);
        assert!(!was_cut);
    }

    /// Content longer than the limit is cut to `SNIPPET_MAX_CHARS` characters.
    #[test]
    fn snippet_over_limit_is_truncated() {
        let input = "b".repeat(SNIPPET_MAX_CHARS + 100);
        let (text, total, was_cut) = build_snippet(&input);
        assert_eq!(text.chars().count(), SNIPPET_MAX_CHARS);
        assert_eq!(total, SNIPPET_MAX_CHARS + 100);
        assert!(was_cut);
    }

    /// Lengths are measured in characters, so multi-byte emoji count once each.
    #[test]
    fn snippet_counts_unicode_chars_not_bytes() {
        let input = "\u{1F600}".repeat(SNIPPET_MAX_CHARS + 1);
        let (text, total, was_cut) = build_snippet(&input);
        assert_eq!(text.chars().count(), SNIPPET_MAX_CHARS);
        assert_eq!(total, SNIPPET_MAX_CHARS + 1);
        assert!(was_cut);
    }

    /// The empty string yields an empty, non-truncated snippet.
    #[test]
    fn snippet_empty_content() {
        let (text, total, was_cut) = build_snippet("");
        assert_eq!(text, "");
        assert_eq!(total, 0);
        assert!(!was_cut);
    }
}