1use std::{
2 path::{Path, PathBuf},
3 sync::{Arc, Mutex},
4};
5
6use git2::{build::CheckoutBuilder, ErrorCode, MergeOptions, Repository, Signature};
7use tracing::{info, warn};
8
9use secrecy::{ExposeSecret, SecretString};
10
11use crate::{
12 auth::AuthProvider,
13 error::MemoryError,
14 types::{validate_name, ChangedMemories, Memory, PullResult, Scope},
15};
16
/// Redact the userinfo (`user:password@`) part of a URL so it can be logged.
///
/// Everything between the `://` scheme separator and the *last* `@` that
/// follows it is replaced with `[REDACTED]`. Using the last `@` means a token
/// or password that itself contains an unencoded `@` is still fully hidden.
/// The cost is that an `@` appearing later in the path can also hide the host,
/// which is deliberate: this function errs on the side of hiding too much.
/// URLs with no `://` (e.g. scp-style `git@host:org/repo`) or with no `@`
/// after it are returned unchanged.
fn redact_url(url: &str) -> String {
    let authority_start = match url.find("://") {
        Some(scheme_end) => scheme_end + "://".len(),
        None => return url.to_string(),
    };
    let rest = &url[authority_start..];
    match rest.rfind('@') {
        Some(at) => format!("{}[REDACTED]@{}", &url[..authority_start], &rest[at + 1..]),
        None => url.to_string(),
    }
}
34
35fn capture_head_oid(repo: &git2::Repository) -> Result<[u8; 20], MemoryError> {
39 match repo.head() {
40 Ok(h) => {
41 let oid = h.peel_to_commit()?.id();
42 let mut buf = [0u8; 20];
43 buf.copy_from_slice(oid.as_bytes());
44 Ok(buf)
45 }
46 Err(e) if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound => {
48 Ok([0u8; 20])
49 }
50 Err(e) => Err(MemoryError::Git(e)),
51 }
52}
53
54fn fast_forward(
59 repo: &git2::Repository,
60 fetch_commit: &git2::AnnotatedCommit,
61 branch: &str,
62) -> Result<PullResult, MemoryError> {
63 let old_head = capture_head_oid(repo)?;
64
65 let refname = format!("refs/heads/{branch}");
66 let target_oid = fetch_commit.id();
67
68 match repo.find_reference(&refname) {
69 Ok(mut reference) => {
70 reference.set_target(target_oid, &format!("pull: fast-forward to {}", target_oid))?;
71 }
72 Err(e) if e.code() == ErrorCode::NotFound => {
73 repo.reference(
75 &refname,
76 target_oid,
77 true,
78 &format!("pull: create branch {} from fetch", branch),
79 )?;
80 }
81 Err(e) => return Err(MemoryError::Git(e)),
82 }
83
84 repo.set_head(&refname)?;
85 let mut checkout = CheckoutBuilder::default();
86 checkout.force();
87 repo.checkout_head(Some(&mut checkout))?;
88
89 let mut new_head = [0u8; 20];
90 new_head.copy_from_slice(target_oid.as_bytes());
91
92 info!("pull: fast-forwarded to {}", target_oid);
93 Ok(PullResult::FastForward { old_head, new_head })
94}
95
96fn build_auth_callbacks(token: SecretString) -> git2::RemoteCallbacks<'static> {
100 let mut callbacks = git2::RemoteCallbacks::new();
101 callbacks.credentials(move |_url, _username, _allowed| {
102 git2::Cred::userpass_plaintext("x-access-token", token.expose_secret())
103 });
104 callbacks
105}
106
107pub struct MemoryRepo {
109 inner: Mutex<Repository>,
110 root: PathBuf,
111}
112
113unsafe impl Send for MemoryRepo {}
117unsafe impl Sync for MemoryRepo {}
118
119impl MemoryRepo {
120 pub fn init_or_open(path: &Path, remote_url: Option<&str>) -> Result<Self, MemoryError> {
125 let repo = if path.join(".git").exists() {
126 Repository::open(path)?
127 } else {
128 let mut opts = git2::RepositoryInitOptions::new();
129 opts.initial_head("main");
130 let repo = Repository::init_opts(path, &opts)?;
131 let gitignore = path.join(".gitignore");
133 if !gitignore.exists() {
134 std::fs::write(&gitignore, ".memory-mcp-index/\n")?;
135 }
136 {
138 let mut index = repo.index()?;
139 index.add_path(Path::new(".gitignore"))?;
140 index.write()?;
141 let tree_oid = index.write_tree()?;
142 let tree = repo.find_tree(tree_oid)?;
143 let sig = Signature::now("memory-mcp", "memory-mcp@local")?;
144 repo.commit(
145 Some("HEAD"),
146 &sig,
147 &sig,
148 "chore: init repository",
149 &tree,
150 &[],
151 )?;
152 }
153 repo
154 };
155
156 if let Some(url) = remote_url {
158 match repo.find_remote("origin") {
159 Ok(existing) => {
160 let current_url = existing.url().unwrap_or("");
162 if current_url != url {
163 repo.remote_set_url("origin", url)?;
164 info!("updated origin remote URL to {}", redact_url(url));
165 }
166 }
167 Err(e) if e.code() == ErrorCode::NotFound => {
168 repo.remote("origin", url)?;
169 info!("created origin remote pointing at {}", redact_url(url));
170 }
171 Err(e) => return Err(MemoryError::Git(e)),
172 }
173 }
174
175 Ok(Self {
176 inner: Mutex::new(repo),
177 root: path.to_path_buf(),
178 })
179 }
180
181 fn memory_path(&self, name: &str, scope: &Scope) -> PathBuf {
183 self.root
184 .join(scope.dir_prefix())
185 .join(format!("{}.md", name))
186 }
187
188 pub async fn save_memory(self: &Arc<Self>, memory: &Memory) -> Result<(), MemoryError> {
193 validate_name(&memory.name)?;
194 if let Scope::Project(ref project_name) = memory.metadata.scope {
195 validate_name(project_name)?;
196 }
197
198 let file_path = self.memory_path(&memory.name, &memory.metadata.scope);
199 self.assert_within_root(&file_path)?;
200
201 let arc = Arc::clone(self);
202 let memory = memory.clone();
203 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
204 let repo = arc
205 .inner
206 .lock()
207 .expect("lock poisoned — prior panic corrupted state");
208
209 if let Some(parent) = file_path.parent() {
211 std::fs::create_dir_all(parent)?;
212 }
213
214 let markdown = memory.to_markdown()?;
215 arc.write_memory_file(&file_path, markdown.as_bytes())?;
216
217 arc.git_add_and_commit(
218 &repo,
219 &file_path,
220 &format!("chore: save memory '{}'", memory.name),
221 )?;
222 Ok(())
223 })
224 .await
225 .map_err(|e| MemoryError::Join(e.to_string()))?
226 }
227
228 pub async fn delete_memory(
230 self: &Arc<Self>,
231 name: &str,
232 scope: &Scope,
233 ) -> Result<(), MemoryError> {
234 validate_name(name)?;
235 if let Scope::Project(ref project_name) = *scope {
236 validate_name(project_name)?;
237 }
238
239 let file_path = self.memory_path(name, scope);
240 self.assert_within_root(&file_path)?;
241
242 let arc = Arc::clone(self);
243 let name = name.to_string();
244 let file_path_clone = file_path.clone();
245 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
246 let repo = arc
247 .inner
248 .lock()
249 .expect("lock poisoned — prior panic corrupted state");
250
251 match std::fs::symlink_metadata(&file_path_clone) {
253 Err(_) => return Err(MemoryError::NotFound { name: name.clone() }),
254 Ok(m) if m.file_type().is_symlink() => {
255 return Err(MemoryError::InvalidInput {
256 reason: format!(
257 "path '{}' is a symlink, which is not permitted",
258 file_path_clone.display()
259 ),
260 });
261 }
262 Ok(_) => {}
263 }
264
265 std::fs::remove_file(&file_path_clone)?;
266 let relative =
268 file_path_clone
269 .strip_prefix(&arc.root)
270 .map_err(|e| MemoryError::InvalidInput {
271 reason: format!("path strip error: {}", e),
272 })?;
273 let mut index = repo.index()?;
274 index.remove_path(relative)?;
275 index.write()?;
276
277 let tree_oid = index.write_tree()?;
278 let tree = repo.find_tree(tree_oid)?;
279 let sig = arc.signature(&repo)?;
280 let message = format!("chore: delete memory '{}'", name);
281
282 match repo.head() {
283 Ok(head) => {
284 let parent_commit = head.peel_to_commit()?;
285 repo.commit(Some("HEAD"), &sig, &sig, &message, &tree, &[&parent_commit])?;
286 }
287 Err(e)
288 if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound =>
289 {
290 repo.commit(Some("HEAD"), &sig, &sig, &message, &tree, &[])?;
291 }
292 Err(e) => return Err(MemoryError::Git(e)),
293 }
294
295 Ok(())
296 })
297 .await
298 .map_err(|e| MemoryError::Join(e.to_string()))?
299 }
300
301 pub async fn read_memory(
303 self: &Arc<Self>,
304 name: &str,
305 scope: &Scope,
306 ) -> Result<Memory, MemoryError> {
307 validate_name(name)?;
308 if let Scope::Project(ref project_name) = *scope {
309 validate_name(project_name)?;
310 }
311
312 let file_path = self.memory_path(name, scope);
313 self.assert_within_root(&file_path)?;
314
315 let arc = Arc::clone(self);
316 let name = name.to_string();
317 tokio::task::spawn_blocking(move || -> Result<Memory, MemoryError> {
318 match std::fs::symlink_metadata(&file_path) {
320 Err(_) => return Err(MemoryError::NotFound { name }),
321 Ok(m) if m.file_type().is_symlink() => {
322 return Err(MemoryError::InvalidInput {
323 reason: format!(
324 "path '{}' is a symlink, which is not permitted",
325 file_path.display()
326 ),
327 });
328 }
329 Ok(_) => {}
330 }
331 let raw = arc.read_memory_file(&file_path)?;
332 Memory::from_markdown(&raw)
333 })
334 .await
335 .map_err(|e| MemoryError::Join(e.to_string()))?
336 }
337
338 pub async fn list_memories(
340 self: &Arc<Self>,
341 scope: Option<&Scope>,
342 ) -> Result<Vec<Memory>, MemoryError> {
343 let root = self.root.clone();
344 let scope_clone = scope.cloned();
345
346 tokio::task::spawn_blocking(move || -> Result<Vec<Memory>, MemoryError> {
347 let dirs: Vec<PathBuf> = match scope_clone.as_ref() {
348 Some(s) => vec![root.join(s.dir_prefix())],
349 None => {
350 let mut dirs = Vec::new();
352 let global = root.join("global");
353 if global.exists() {
354 dirs.push(global);
355 }
356 let projects = root.join("projects");
357 if projects.exists() {
358 for entry in std::fs::read_dir(&projects)? {
359 let entry = entry?;
360 if entry.file_type()?.is_dir() {
361 dirs.push(entry.path());
362 }
363 }
364 }
365 dirs
366 }
367 };
368
369 fn collect_md_files(dir: &Path, out: &mut Vec<Memory>) -> Result<(), MemoryError> {
370 if !dir.exists() {
371 return Ok(());
372 }
373 for entry in std::fs::read_dir(dir)? {
374 let entry = entry?;
375 let path = entry.path();
376 let ft = entry.file_type()?;
377 if ft.is_symlink() {
379 warn!(
380 "skipping symlink at {:?} — symlinks are not permitted in the memory store",
381 path
382 );
383 continue;
384 }
385 if ft.is_dir() {
386 collect_md_files(&path, out)?;
387 } else if path.extension().and_then(|e| e.to_str()) == Some("md") {
388 let raw = std::fs::read_to_string(&path)?;
389 match Memory::from_markdown(&raw) {
390 Ok(m) => out.push(m),
391 Err(e) => {
392 warn!("skipping {:?}: {}", path, e);
393 }
394 }
395 }
396 }
397 Ok(())
398 }
399
400 let mut memories = Vec::new();
401 for dir in dirs {
402 collect_md_files(&dir, &mut memories)?;
403 }
404
405 Ok(memories)
406 })
407 .await
408 .map_err(|e| MemoryError::Join(e.to_string()))?
409 }
410
411 pub async fn push(
416 self: &Arc<Self>,
417 auth: &AuthProvider,
418 branch: &str,
419 ) -> Result<(), MemoryError> {
420 let token_result = auth.resolve_token();
424 let arc = Arc::clone(self);
425 let branch = branch.to_string();
426
427 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
428 let repo = arc
429 .inner
430 .lock()
431 .expect("lock poisoned — prior panic corrupted state");
432
433 let mut remote = match repo.find_remote("origin") {
434 Ok(r) => r,
435 Err(e) if e.code() == ErrorCode::NotFound => {
436 warn!("push: no origin remote configured — skipping (local-only mode)");
437 return Ok(());
438 }
439 Err(e) => return Err(MemoryError::Git(e)),
440 };
441
442 let token = token_result?;
444 let callbacks = build_auth_callbacks(token);
445 let mut push_opts = git2::PushOptions::new();
446 push_opts.remote_callbacks(callbacks);
447
448 let refspec = format!("refs/heads/{branch}:refs/heads/{branch}");
449 remote.push(&[&refspec], Some(&mut push_opts))?;
450 info!("pushed branch '{}' to origin", branch);
451 Ok(())
452 })
453 .await
454 .map_err(|e| MemoryError::Join(e.to_string()))?
455 }
456
457 fn merge_with_remote(
462 &self,
463 repo: &git2::Repository,
464 fetch_commit: &git2::AnnotatedCommit,
465 branch: &str,
466 ) -> Result<PullResult, MemoryError> {
467 let oid = repo.head()?.peel_to_commit()?.id();
471 let mut old_head = [0u8; 20];
472 old_head.copy_from_slice(oid.as_bytes());
473
474 let mut merge_opts = MergeOptions::new();
475 merge_opts.fail_on_conflict(false);
476 repo.merge(&[fetch_commit], Some(&mut merge_opts), None)?;
477
478 let mut index = repo.index()?;
479 let conflicts_resolved = if index.has_conflicts() {
480 self.resolve_conflicts_by_recency(repo, &mut index)?
481 } else {
482 0
483 };
484
485 if index.has_conflicts() {
489 let _ = repo.cleanup_state();
490 return Err(MemoryError::Internal(
491 "unresolved conflicts remain after auto-resolution".into(),
492 ));
493 }
494
495 index.write()?;
497 let tree_oid = index.write_tree()?;
498 let tree = repo.find_tree(tree_oid)?;
499 let sig = self.signature(repo)?;
500
501 let head_commit = repo.head()?.peel_to_commit()?;
502 let fetch_commit_obj = repo.find_commit(fetch_commit.id())?;
503
504 let new_commit_oid = repo.commit(
505 Some("HEAD"),
506 &sig,
507 &sig,
508 &format!("chore: merge origin/{}", branch),
509 &tree,
510 &[&head_commit, &fetch_commit_obj],
511 )?;
512
513 repo.cleanup_state()?;
514
515 let mut new_head = [0u8; 20];
516 new_head.copy_from_slice(new_commit_oid.as_bytes());
517
518 info!(
519 "pull: merge complete ({} conflicts auto-resolved)",
520 conflicts_resolved
521 );
522 Ok(PullResult::Merged {
523 conflicts_resolved,
524 old_head,
525 new_head,
526 })
527 }
528
529 pub async fn pull(
535 self: &Arc<Self>,
536 auth: &AuthProvider,
537 branch: &str,
538 ) -> Result<PullResult, MemoryError> {
539 let token_result = auth.resolve_token();
543 let arc = Arc::clone(self);
544 let branch = branch.to_string();
545
546 tokio::task::spawn_blocking(move || -> Result<PullResult, MemoryError> {
547 let repo = arc
548 .inner
549 .lock()
550 .expect("lock poisoned — prior panic corrupted state");
551
552 let mut remote = match repo.find_remote("origin") {
554 Ok(r) => r,
555 Err(e) if e.code() == ErrorCode::NotFound => {
556 warn!("pull: no origin remote configured — skipping (local-only mode)");
557 return Ok(PullResult::NoRemote);
558 }
559 Err(e) => return Err(MemoryError::Git(e)),
560 };
561
562 let token = token_result?;
564
565 let callbacks = build_auth_callbacks(token);
567 let mut fetch_opts = git2::FetchOptions::new();
568 fetch_opts.remote_callbacks(callbacks);
569 remote.fetch(&[&branch], Some(&mut fetch_opts), None)?;
570
571 let fetch_head = match repo.find_reference("FETCH_HEAD") {
573 Ok(r) => r,
574 Err(e) if e.code() == ErrorCode::NotFound => {
575 return Ok(PullResult::UpToDate);
577 }
578 Err(e)
579 if e.class() == git2::ErrorClass::Reference
580 && e.message().contains("corrupted") =>
581 {
582 info!("pull: FETCH_HEAD is empty or corrupted — treating as empty remote");
584 return Ok(PullResult::UpToDate);
585 }
586 Err(e) => return Err(MemoryError::Git(e)),
587 };
588 let fetch_commit = match repo.reference_to_annotated_commit(&fetch_head) {
589 Ok(c) => c,
590 Err(e) if e.class() == git2::ErrorClass::Reference => {
591 info!("pull: FETCH_HEAD not resolvable — treating as empty remote");
593 return Ok(PullResult::UpToDate);
594 }
595 Err(e) => return Err(MemoryError::Git(e)),
596 };
597
598 let (analysis, _preference) = repo.merge_analysis(&[&fetch_commit])?;
600
601 if analysis.is_up_to_date() {
602 info!("pull: already up to date");
603 return Ok(PullResult::UpToDate);
604 }
605
606 if analysis.is_fast_forward() {
607 return fast_forward(&repo, &fetch_commit, &branch);
608 }
609
610 arc.merge_with_remote(&repo, &fetch_commit, &branch)
611 })
612 .await
613 .map_err(|e| MemoryError::Join(e.to_string()))?
614 }
615
616 pub fn diff_changed_memories(
624 &self,
625 old_oid: [u8; 20],
626 new_oid: [u8; 20],
627 ) -> Result<ChangedMemories, MemoryError> {
628 let repo = self
629 .inner
630 .lock()
631 .expect("lock poisoned — prior panic corrupted state");
632
633 let new_git_oid = git2::Oid::from_bytes(&new_oid).map_err(MemoryError::Git)?;
634 let new_tree = repo.find_commit(new_git_oid)?.tree()?;
635
636 let diff = if old_oid == [0u8; 20] {
639 repo.diff_tree_to_tree(None, Some(&new_tree), None)?
640 } else {
641 let old_git_oid = git2::Oid::from_bytes(&old_oid).map_err(MemoryError::Git)?;
642 let old_tree = repo.find_commit(old_git_oid)?.tree()?;
643 repo.diff_tree_to_tree(Some(&old_tree), Some(&new_tree), None)?
644 };
645
646 let mut changes = ChangedMemories::default();
647
648 diff.foreach(
649 &mut |delta, _progress| {
650 use git2::Delta;
651
652 let path = match delta.new_file().path().or_else(|| delta.old_file().path()) {
653 Some(p) => p,
654 None => return true,
655 };
656
657 let path_str = match path.to_str() {
658 Some(s) => s,
659 None => return true,
660 };
661
662 if !path_str.ends_with(".md") {
664 return true;
665 }
666 if !path_str.starts_with("global/") && !path_str.starts_with("projects/") {
667 return true;
668 }
669
670 let qualified = &path_str[..path_str.len() - 3];
672
673 match delta.status() {
674 Delta::Added | Delta::Modified => {
675 changes.upserted.push(qualified.to_string());
676 }
677 Delta::Renamed | Delta::Copied => {
678 if matches!(delta.status(), Delta::Renamed) {
681 if let Some(old_path) = delta.old_file().path().and_then(|p| p.to_str())
682 {
683 if old_path.ends_with(".md")
684 && (old_path.starts_with("global/")
685 || old_path.starts_with("projects/"))
686 {
687 changes
688 .removed
689 .push(old_path[..old_path.len() - 3].to_string());
690 }
691 }
692 }
693 changes.upserted.push(qualified.to_string());
694 }
695 Delta::Deleted => {
696 changes.removed.push(qualified.to_string());
697 }
698 _ => {}
699 }
700
701 true
702 },
703 None,
704 None,
705 None,
706 )
707 .map_err(MemoryError::Git)?;
708
709 Ok(changes)
710 }
711
712 fn resolve_conflicts_by_recency(
722 &self,
723 repo: &Repository,
724 index: &mut git2::Index,
725 ) -> Result<usize, MemoryError> {
726 struct ConflictInfo {
728 path: PathBuf,
729 our_blob: Option<Vec<u8>>,
730 their_blob: Option<Vec<u8>>,
731 }
732
733 let mut conflicts_info: Vec<ConflictInfo> = Vec::new();
734
735 {
736 let conflicts = index.conflicts()?;
737 for conflict in conflicts {
738 let conflict = conflict?;
739
740 let path = conflict
741 .our
742 .as_ref()
743 .or(conflict.their.as_ref())
744 .and_then(|e| std::str::from_utf8(&e.path).ok())
745 .map(|s| self.root.join(s));
746
747 let path = match path {
748 Some(p) => p,
749 None => continue,
750 };
751
752 let our_blob = conflict
753 .our
754 .as_ref()
755 .and_then(|e| repo.find_blob(e.id).ok())
756 .map(|b| b.content().to_vec());
757
758 let their_blob = conflict
759 .their
760 .as_ref()
761 .and_then(|e| repo.find_blob(e.id).ok())
762 .map(|b| b.content().to_vec());
763
764 conflicts_info.push(ConflictInfo {
765 path,
766 our_blob,
767 their_blob,
768 });
769 }
770 }
771
772 let mut resolved = 0usize;
773
774 for info in conflicts_info {
775 let our_str = info
776 .our_blob
777 .as_deref()
778 .and_then(|b| std::str::from_utf8(b).ok())
779 .map(str::to_owned);
780 let their_str = info
781 .their_blob
782 .as_deref()
783 .and_then(|b| std::str::from_utf8(b).ok())
784 .map(str::to_owned);
785
786 let our_ts = our_str
787 .as_deref()
788 .and_then(|s| Memory::from_markdown(s).ok())
789 .map(|m| m.metadata.updated_at);
790 let their_ts = their_str
791 .as_deref()
792 .and_then(|s| Memory::from_markdown(s).ok())
793 .map(|m| m.metadata.updated_at);
794
795 let (chosen_bytes, label): (Vec<u8>, String) =
797 match (our_str.as_deref(), their_str.as_deref()) {
798 (Some(ours), Some(theirs)) => match (our_ts, their_ts) {
799 (Some(ot), Some(tt)) if tt > ot => (
800 theirs.as_bytes().to_vec(),
801 format!("theirs (updated_at: {})", tt),
802 ),
803 (Some(ot), _) => (
804 ours.as_bytes().to_vec(),
805 format!("ours (updated_at: {})", ot),
806 ),
807 _ => (
808 ours.as_bytes().to_vec(),
809 "ours (timestamp unparseable)".to_string(),
810 ),
811 },
812 (Some(ours), None) => (
813 ours.as_bytes().to_vec(),
814 "ours (theirs missing)".to_string(),
815 ),
816 (None, Some(theirs)) => (
817 theirs.as_bytes().to_vec(),
818 "theirs (ours missing)".to_string(),
819 ),
820 (None, None) => {
821 match (info.our_blob.as_deref(), info.their_blob.as_deref()) {
823 (Some(ours), _) => {
824 (ours.to_vec(), "ours (binary/non-UTF-8)".to_string())
825 }
826 (_, Some(theirs)) => {
827 (theirs.to_vec(), "theirs (binary/non-UTF-8)".to_string())
828 }
829 (None, None) => {
830 warn!(
833 "conflict at '{}': both sides missing — removing from index",
834 info.path.display()
835 );
836 let relative = info.path.strip_prefix(&self.root).map_err(|e| {
837 MemoryError::InvalidInput {
838 reason: format!(
839 "path strip error during conflict resolution: {}",
840 e
841 ),
842 }
843 })?;
844 index.conflict_remove(relative)?;
845 resolved += 1;
846 continue;
847 }
848 }
849 }
850 };
851
852 warn!(
853 "conflict resolved: {} — kept {}",
854 info.path.display(),
855 label
856 );
857
858 self.assert_within_root(&info.path)?;
862 if let Some(parent) = info.path.parent() {
863 std::fs::create_dir_all(parent)?;
864 }
865 self.write_memory_file(&info.path, &chosen_bytes)?;
866
867 let relative =
869 info.path
870 .strip_prefix(&self.root)
871 .map_err(|e| MemoryError::InvalidInput {
872 reason: format!("path strip error during conflict resolution: {}", e),
873 })?;
874 index.add_path(relative)?;
875
876 resolved += 1;
877 }
878
879 Ok(resolved)
880 }
881
882 fn signature<'r>(&self, repo: &'r Repository) -> Result<Signature<'r>, MemoryError> {
883 let sig = repo
885 .signature()
886 .or_else(|_| Signature::now("memory-mcp", "memory-mcp@local"))?;
887 Ok(sig)
888 }
889
890 fn git_add_and_commit(
892 &self,
893 repo: &Repository,
894 file_path: &Path,
895 message: &str,
896 ) -> Result<(), MemoryError> {
897 let relative =
898 file_path
899 .strip_prefix(&self.root)
900 .map_err(|e| MemoryError::InvalidInput {
901 reason: format!("path strip error: {}", e),
902 })?;
903
904 let mut index = repo.index()?;
905 index.add_path(relative)?;
906 index.write()?;
907
908 let tree_oid = index.write_tree()?;
909 let tree = repo.find_tree(tree_oid)?;
910 let sig = self.signature(repo)?;
911
912 match repo.head() {
913 Ok(head) => {
914 let parent_commit = head.peel_to_commit()?;
915 repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &[&parent_commit])?;
916 }
917 Err(e) if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound => {
918 repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &[])?;
920 }
921 Err(e) => return Err(MemoryError::Git(e)),
922 }
923
924 Ok(())
925 }
926
927 fn assert_within_root(&self, path: &Path) -> Result<(), MemoryError> {
930 let parent = path.parent().unwrap_or(path);
933 let filename = path.file_name().ok_or_else(|| MemoryError::InvalidInput {
934 reason: "path has no filename component".to_string(),
935 })?;
936
937 let canon_parent = {
940 let mut p = parent.to_path_buf();
941 let mut suffixes: Vec<std::ffi::OsString> = Vec::new();
942 loop {
943 match p.canonicalize() {
944 Ok(c) => {
945 let mut full = c;
946 for s in suffixes.into_iter().rev() {
947 full.push(s);
948 }
949 break full;
950 }
951 Err(_) => {
952 if let Some(name) = p.file_name() {
953 suffixes.push(name.to_os_string());
954 }
955 match p.parent() {
956 Some(par) => p = par.to_path_buf(),
957 None => {
958 return Err(MemoryError::InvalidInput {
959 reason: "cannot resolve any ancestor of path".into(),
960 });
961 }
962 }
963 }
964 }
965 }
966 };
967
968 let resolved = canon_parent.join(filename);
969
970 let canon_root = self
971 .root
972 .canonicalize()
973 .map_err(|e| MemoryError::InvalidInput {
974 reason: format!("cannot canonicalize repo root: {}", e),
975 })?;
976
977 if !resolved.starts_with(&canon_root) {
978 return Err(MemoryError::InvalidInput {
979 reason: format!(
980 "path '{}' escapes repository root '{}'",
981 resolved.display(),
982 canon_root.display()
983 ),
984 });
985 }
986
987 {
992 let mut probe = canon_root.clone();
993 let relative =
995 resolved
996 .strip_prefix(&canon_root)
997 .map_err(|e| MemoryError::InvalidInput {
998 reason: format!("path strip error: {}", e),
999 })?;
1000 for component in relative.components() {
1001 probe.push(component);
1002 if (probe.exists() || probe.symlink_metadata().is_ok())
1004 && probe
1005 .symlink_metadata()
1006 .map(|m| m.file_type().is_symlink())
1007 .unwrap_or(false)
1008 {
1009 return Err(MemoryError::InvalidInput {
1010 reason: format!(
1011 "path component '{}' is a symlink, which is not allowed",
1012 probe.display()
1013 ),
1014 });
1015 }
1016 }
1017 }
1018
1019 Ok(())
1020 }
1021
1022 fn write_memory_file(&self, path: &Path, data: &[u8]) -> Result<(), MemoryError> {
1027 #[cfg(unix)]
1028 {
1029 use std::io::Write as _;
1030 use std::os::unix::fs::OpenOptionsExt as _;
1031 let mut f = std::fs::OpenOptions::new()
1032 .write(true)
1033 .create(true)
1034 .truncate(true)
1035 .custom_flags(libc::O_NOFOLLOW)
1036 .open(path)?;
1037 f.write_all(data)?;
1038 Ok(())
1039 }
1040 #[cfg(not(unix))]
1041 {
1042 std::fs::write(path, data)?;
1043 Ok(())
1044 }
1045 }
1046
1047 fn read_memory_file(&self, path: &Path) -> Result<String, MemoryError> {
1052 #[cfg(unix)]
1053 {
1054 use std::io::Read as _;
1055 use std::os::unix::fs::OpenOptionsExt as _;
1056 let mut f = std::fs::OpenOptions::new()
1057 .read(true)
1058 .custom_flags(libc::O_NOFOLLOW)
1059 .open(path)?;
1060 let mut buf = String::new();
1061 f.read_to_string(&mut buf)?;
1062 Ok(buf)
1063 }
1064 #[cfg(not(unix))]
1065 {
1066 Ok(std::fs::read_to_string(path)?)
1067 }
1068 }
1069}
1070
1071#[cfg(test)]
1076mod tests {
1077 use super::*;
1078 use crate::auth::AuthProvider;
1079 use crate::types::{Memory, MemoryMetadata, PullResult, Scope};
1080 use std::sync::Arc;
1081
1082 fn test_auth() -> AuthProvider {
1083 AuthProvider::with_token("test-token-unused-for-file-remotes")
1084 }
1085
1086 fn make_memory(name: &str, content: &str, updated_at_secs: i64) -> Memory {
1087 let meta = MemoryMetadata {
1088 tags: vec![],
1089 scope: Scope::Global,
1090 created_at: chrono::DateTime::from_timestamp(1_700_000_000, 0).unwrap(),
1091 updated_at: chrono::DateTime::from_timestamp(updated_at_secs, 0).unwrap(),
1092 source: None,
1093 };
1094 Memory::new(name.to_string(), content.to_string(), meta)
1095 }
1096
1097 fn setup_bare_remote() -> (tempfile::TempDir, String) {
1098 let dir = tempfile::tempdir().expect("failed to create temp dir");
1099 git2::Repository::init_bare(dir.path()).expect("failed to init bare repo");
1100 let url = format!("file://{}", dir.path().display());
1101 (dir, url)
1102 }
1103
1104 fn open_repo(dir: &tempfile::TempDir, remote_url: Option<&str>) -> Arc<MemoryRepo> {
1105 Arc::new(MemoryRepo::init_or_open(dir.path(), remote_url).expect("failed to init repo"))
1106 }
1107
1108 #[test]
1111 fn redact_url_strips_userinfo() {
1112 assert_eq!(
1113 redact_url("https://user:ghp_token123@github.com/org/repo.git"),
1114 "https://[REDACTED]@github.com/org/repo.git"
1115 );
1116 }
1117
1118 #[test]
1119 fn redact_url_no_at_passthrough() {
1120 let url = "https://github.com/org/repo.git";
1121 assert_eq!(redact_url(url), url);
1122 }
1123
1124 #[test]
1125 fn redact_url_file_protocol_passthrough() {
1126 let url = "file:///tmp/bare.git";
1127 assert_eq!(redact_url(url), url);
1128 }
1129
1130 #[test]
1133 fn assert_within_root_accepts_valid_path() {
1134 let dir = tempfile::tempdir().unwrap();
1135 let repo = MemoryRepo::init_or_open(dir.path(), None).unwrap();
1136 let valid = dir.path().join("global").join("my-memory.md");
1137 std::fs::create_dir_all(valid.parent().unwrap()).unwrap();
1139 assert!(repo.assert_within_root(&valid).is_ok());
1140 }
1141
1142 #[test]
1143 fn assert_within_root_rejects_escape() {
1144 let dir = tempfile::tempdir().unwrap();
1145 let repo = MemoryRepo::init_or_open(dir.path(), None).unwrap();
1146 let _evil = dir
1149 .path()
1150 .join("..")
1151 .join("..")
1152 .join("..")
1153 .join("tmp")
1154 .join("evil.md");
1155 let outside = std::path::PathBuf::from("/tmp/definitely-outside");
1159 assert!(repo.assert_within_root(&outside).is_err());
1160 }
1161
1162 #[tokio::test]
1165 async fn push_local_only_returns_ok() {
1166 let dir = tempfile::tempdir().unwrap();
1167 let repo = open_repo(&dir, None);
1168 let auth = test_auth();
1169 let result = repo.push(&auth, "main").await;
1171 assert!(result.is_ok());
1172 }
1173
1174 #[tokio::test]
1175 async fn pull_local_only_returns_no_remote() {
1176 let dir = tempfile::tempdir().unwrap();
1177 let repo = open_repo(&dir, None);
1178 let auth = test_auth();
1179 let result = repo.pull(&auth, "main").await.unwrap();
1180 assert!(matches!(result, PullResult::NoRemote));
1181 }
1182
1183 #[tokio::test]
1186 async fn push_to_bare_remote() {
1187 let (_remote_dir, remote_url) = setup_bare_remote();
1188 let local_dir = tempfile::tempdir().unwrap();
1189 let repo = open_repo(&local_dir, Some(&remote_url));
1190 let auth = test_auth();
1191
1192 let mem = make_memory("test-push", "push content", 1_700_000_000);
1194 repo.save_memory(&mem).await.unwrap();
1195
1196 repo.push(&auth, "main").await.unwrap();
1198
1199 let bare = git2::Repository::open_bare(_remote_dir.path()).unwrap();
1201 let head = bare.find_reference("refs/heads/main").unwrap();
1202 let commit = head.peel_to_commit().unwrap();
1203 assert!(commit.message().unwrap().contains("test-push"));
1204 }
1205
1206 #[tokio::test]
1207 async fn pull_from_empty_bare_remote_returns_up_to_date() {
1208 let (_remote_dir, remote_url) = setup_bare_remote();
1209 let local_dir = tempfile::tempdir().unwrap();
1210 let repo = open_repo(&local_dir, Some(&remote_url));
1211 let auth = test_auth();
1212
1213 let mem = make_memory("seed", "seed content", 1_700_000_000);
1215 repo.save_memory(&mem).await.unwrap();
1216
1217 let result = repo.pull(&auth, "main").await.unwrap();
1219 assert!(matches!(result, PullResult::UpToDate));
1220 }
1221
1222 #[tokio::test]
1223 async fn pull_fast_forward() {
1224 let (_remote_dir, remote_url) = setup_bare_remote();
1225 let auth = test_auth();
1226
1227 let dir_a = tempfile::tempdir().unwrap();
1229 let repo_a = open_repo(&dir_a, Some(&remote_url));
1230 let mem = make_memory("from-a", "content from A", 1_700_000_000);
1231 repo_a.save_memory(&mem).await.unwrap();
1232 repo_a.push(&auth, "main").await.unwrap();
1233
1234 let dir_b = tempfile::tempdir().unwrap();
1236 let repo_b = open_repo(&dir_b, Some(&remote_url));
1237 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1239 repo_b.save_memory(&seed).await.unwrap();
1240
1241 let result = repo_b.pull(&auth, "main").await.unwrap();
1242 assert!(
1243 matches!(
1244 result,
1245 PullResult::FastForward { .. } | PullResult::Merged { .. }
1246 ),
1247 "expected fast-forward or merge, got {:?}",
1248 result
1249 );
1250
1251 let file = dir_b.path().join("global").join("from-a.md");
1253 assert!(file.exists(), "from-a.md should exist in repo B after pull");
1254 }
1255
1256 #[tokio::test]
1257 async fn pull_up_to_date_after_push() {
1258 let (_remote_dir, remote_url) = setup_bare_remote();
1259 let local_dir = tempfile::tempdir().unwrap();
1260 let repo = open_repo(&local_dir, Some(&remote_url));
1261 let auth = test_auth();
1262
1263 let mem = make_memory("synced", "synced content", 1_700_000_000);
1264 repo.save_memory(&mem).await.unwrap();
1265 repo.push(&auth, "main").await.unwrap();
1266
1267 let result = repo.pull(&auth, "main").await.unwrap();
1269 assert!(matches!(result, PullResult::UpToDate));
1270 }
1271
1272 #[tokio::test]
1275 async fn pull_merge_conflict_theirs_newer_wins() {
1276 let (_remote_dir, remote_url) = setup_bare_remote();
1277 let auth = test_auth();
1278
1279 let dir_a = tempfile::tempdir().unwrap();
1281 let repo_a = open_repo(&dir_a, Some(&remote_url));
1282 let mem_a1 = make_memory("shared", "version from A initial", 1_700_000_100);
1283 repo_a.save_memory(&mem_a1).await.unwrap();
1284 repo_a.push(&auth, "main").await.unwrap();
1285
1286 let dir_b = tempfile::tempdir().unwrap();
1288 let repo_b = open_repo(&dir_b, Some(&remote_url));
1289 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1290 repo_b.save_memory(&seed).await.unwrap();
1291 repo_b.pull(&auth, "main").await.unwrap();
1292
1293 let mem_b = make_memory("shared", "version from B (newer)", 1_700_000_300);
1294 repo_b.save_memory(&mem_b).await.unwrap();
1295 repo_b.push(&auth, "main").await.unwrap();
1296
1297 let mem_a2 = make_memory("shared", "version from A (older)", 1_700_000_200);
1299 repo_a.save_memory(&mem_a2).await.unwrap();
1300 let result = repo_a.pull(&auth, "main").await.unwrap();
1301
1302 assert!(
1303 matches!(result, PullResult::Merged { conflicts_resolved, .. } if conflicts_resolved >= 1),
1304 "expected merge with conflicts resolved, got {:?}",
1305 result
1306 );
1307
1308 let file = dir_a.path().join("global").join("shared.md");
1310 let content = std::fs::read_to_string(&file).unwrap();
1311 assert!(
1312 content.contains("version from B (newer)"),
1313 "expected B's version to win (newer timestamp), got: {}",
1314 content
1315 );
1316 }
1317
1318 #[tokio::test]
1319 async fn pull_merge_conflict_ours_newer_wins() {
1320 let (_remote_dir, remote_url) = setup_bare_remote();
1321 let auth = test_auth();
1322
1323 let dir_a = tempfile::tempdir().unwrap();
1325 let repo_a = open_repo(&dir_a, Some(&remote_url));
1326 let mem_a1 = make_memory("shared", "version from A initial", 1_700_000_100);
1327 repo_a.save_memory(&mem_a1).await.unwrap();
1328 repo_a.push(&auth, "main").await.unwrap();
1329
1330 let dir_b = tempfile::tempdir().unwrap();
1332 let repo_b = open_repo(&dir_b, Some(&remote_url));
1333 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1334 repo_b.save_memory(&seed).await.unwrap();
1335 repo_b.pull(&auth, "main").await.unwrap();
1336
1337 let mem_b = make_memory("shared", "version from B (older)", 1_700_000_200);
1338 repo_b.save_memory(&mem_b).await.unwrap();
1339 repo_b.push(&auth, "main").await.unwrap();
1340
1341 let mem_a2 = make_memory("shared", "version from A (newer)", 1_700_000_300);
1343 repo_a.save_memory(&mem_a2).await.unwrap();
1344 let result = repo_a.pull(&auth, "main").await.unwrap();
1345
1346 assert!(
1347 matches!(result, PullResult::Merged { conflicts_resolved, .. } if conflicts_resolved >= 1),
1348 "expected merge with conflicts resolved, got {:?}",
1349 result
1350 );
1351
1352 let file = dir_a.path().join("global").join("shared.md");
1354 let content = std::fs::read_to_string(&file).unwrap();
1355 assert!(
1356 content.contains("version from A (newer)"),
1357 "expected A's version to win (newer timestamp), got: {}",
1358 content
1359 );
1360 }
1361
1362 #[tokio::test]
1363 async fn pull_merge_no_conflict_different_files() {
1364 let (_remote_dir, remote_url) = setup_bare_remote();
1365 let auth = test_auth();
1366
1367 let dir_a = tempfile::tempdir().unwrap();
1369 let repo_a = open_repo(&dir_a, Some(&remote_url));
1370 let mem_a = make_memory("mem-a", "from A", 1_700_000_100);
1371 repo_a.save_memory(&mem_a).await.unwrap();
1372 repo_a.push(&auth, "main").await.unwrap();
1373
1374 let dir_b = tempfile::tempdir().unwrap();
1376 let repo_b = open_repo(&dir_b, Some(&remote_url));
1377 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1378 repo_b.save_memory(&seed).await.unwrap();
1379 repo_b.pull(&auth, "main").await.unwrap();
1380 let mem_b = make_memory("mem-b", "from B", 1_700_000_200);
1381 repo_b.save_memory(&mem_b).await.unwrap();
1382 repo_b.push(&auth, "main").await.unwrap();
1383
1384 let mem_a2 = make_memory("mem-a2", "also from A", 1_700_000_300);
1386 repo_a.save_memory(&mem_a2).await.unwrap();
1387 let result = repo_a.pull(&auth, "main").await.unwrap();
1388
1389 assert!(
1390 matches!(
1391 result,
1392 PullResult::Merged {
1393 conflicts_resolved: 0,
1394 ..
1395 }
1396 ),
1397 "expected clean merge, got {:?}",
1398 result
1399 );
1400
1401 assert!(dir_a.path().join("global").join("mem-b.md").exists());
1403 }
1404
1405 fn commit_file(repo: &Arc<MemoryRepo>, rel_path: &str, content: &str) -> [u8; 20] {
1409 let inner = repo.inner.lock().expect("lock poisoned");
1410 let full_path = repo.root.join(rel_path);
1411 if let Some(parent) = full_path.parent() {
1412 std::fs::create_dir_all(parent).unwrap();
1413 }
1414 std::fs::write(&full_path, content).unwrap();
1415
1416 let mut index = inner.index().unwrap();
1417 index.add_path(std::path::Path::new(rel_path)).unwrap();
1418 index.write().unwrap();
1419 let tree_oid = index.write_tree().unwrap();
1420 let tree = inner.find_tree(tree_oid).unwrap();
1421 let sig = git2::Signature::now("test", "test@test.com").unwrap();
1422
1423 let oid = match inner.head() {
1424 Ok(head) => {
1425 let parent = head.peel_to_commit().unwrap();
1426 inner
1427 .commit(Some("HEAD"), &sig, &sig, "test commit", &tree, &[&parent])
1428 .unwrap()
1429 }
1430 Err(_) => inner
1431 .commit(Some("HEAD"), &sig, &sig, "initial commit", &tree, &[])
1432 .unwrap(),
1433 };
1434
1435 let mut buf = [0u8; 20];
1436 buf.copy_from_slice(oid.as_bytes());
1437 buf
1438 }
1439
/// A new `.md` file under `global/` shows up as an upsert keyed by its path
/// without the extension.
#[test]
fn diff_changed_memories_detects_added_global() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    // Snapshot HEAD before the new memory file is committed.
    let before: [u8; 20] = {
        let guard = repo.inner.lock().unwrap();
        let head_commit = guard.head().unwrap().peel_to_commit().unwrap();
        let mut bytes = [0u8; 20];
        bytes.copy_from_slice(head_commit.id().as_bytes());
        bytes
    };

    let after = commit_file(&repo, "global/my-note.md", "# content");

    let changes = repo.diff_changed_memories(before, after).unwrap();
    assert_eq!(changes.upserted, vec!["global/my-note".to_string()]);
    assert!(changes.removed.is_empty());
}
1460
/// Deleting a committed memory file shows up as a removal (and not as an
/// upsert) between the two commits.
#[test]
fn diff_changed_memories_detects_deleted() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let rel_path = std::path::Path::new("global/to-delete.md");
    let first_oid = commit_file(&repo, "global/to-delete.md", "hello");

    // Remove the file from the worktree and the index, then commit on HEAD.
    let second_oid = {
        let inner = repo.inner.lock().unwrap();
        // Resolve against `repo.root`, as `commit_file` does, so the worktree
        // path and the index path always refer to the same file.
        std::fs::remove_file(repo.root.join(rel_path)).unwrap();
        let mut index = inner.index().unwrap();
        index.remove_path(rel_path).unwrap();
        index.write().unwrap();
        let tree_oid = index.write_tree().unwrap();
        let tree = inner.find_tree(tree_oid).unwrap();
        let sig = git2::Signature::now("test", "test@test.com").unwrap();
        let parent = inner.head().unwrap().peel_to_commit().unwrap();
        let oid = inner
            .commit(Some("HEAD"), &sig, &sig, "delete file", &tree, &[&parent])
            .unwrap();
        let mut buf = [0u8; 20];
        buf.copy_from_slice(oid.as_bytes());
        buf
    };

    let changes = repo.diff_changed_memories(first_oid, second_oid).unwrap();
    assert!(changes.upserted.is_empty());
    assert_eq!(changes.removed, vec!["global/to-delete".to_string()]);
}
1492
/// Files that are not Markdown, or that live outside the memory directories,
/// must not be reported as memory changes.
#[test]
fn diff_changed_memories_ignores_non_md_files() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    // Baseline: the commit HEAD points at right after `open_repo`.
    let baseline: [u8; 20] = {
        let guard = repo.inner.lock().unwrap();
        let commit = guard.head().unwrap().peel_to_commit().unwrap();
        let mut raw = [0u8; 20];
        raw.copy_from_slice(commit.id().as_bytes());
        raw
    };

    // Neither commit should register: the first file is not Markdown, and
    // the second is Markdown under a directory that is not a memory scope.
    commit_file(&repo, "global/config.json", "{}");
    let latest = commit_file(&repo, "other/note.md", "# ignored");

    let changes = repo.diff_changed_memories(baseline, latest).unwrap();
    assert!(
        changes.upserted.is_empty(),
        "should ignore non-.md and out-of-scope files"
    );
    assert!(changes.removed.is_empty());
}
1524
/// Rewriting an existing project-scoped memory is reported as an upsert of
/// that memory and nothing else.
#[test]
fn diff_changed_memories_detects_modified() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let note = "projects/myproject/note.md";
    let v1 = commit_file(&repo, note, "version 1");
    let v2 = commit_file(&repo, note, "version 2");

    let changes = repo.diff_changed_memories(v1, v2).unwrap();
    let expected = vec!["projects/myproject/note".to_string()];
    assert_eq!(changes.upserted, expected);
    assert!(changes.removed.is_empty());
}
1540
/// An all-zero "old" OID (the value `capture_head_oid` yields when there was
/// no previous HEAD) makes every memory in the new tree an upsert.
#[test]
fn diff_changed_memories_zero_oid_treats_all_as_added() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let latest = commit_file(&repo, "global/first-memory.md", "# Hello");
    let changes = repo.diff_changed_memories([0u8; 20], latest).unwrap();

    assert_eq!(
        changes.upserted,
        vec!["global/first-memory".to_string()],
        "zero OID: all new-tree files should be additions"
    );
    assert!(changes.removed.is_empty(), "zero OID: no removals expected");
}
1562}