1use std::{
2 path::{Path, PathBuf},
3 sync::{Arc, Mutex},
4};
5
6use git2::{build::CheckoutBuilder, ErrorCode, MergeOptions, Repository, Signature};
7use tracing::{info, warn};
8
9use secrecy::{ExposeSecret, SecretString};
10
11use crate::{
12 auth::AuthProvider,
13 error::MemoryError,
14 types::{validate_name, ChangedMemories, Memory, PullResult, Scope},
15};
16
/// Returns `url` with any userinfo (`user:password@`) replaced by `[REDACTED]`,
/// so remote URLs can be logged without leaking embedded credentials.
///
/// Only the authority section (between `://` and the first `/`, `?` or `#`)
/// is inspected: an `@` appearing later in the path or query is not userinfo
/// and is left alone. The *last* `@` in the authority ends the userinfo, so a
/// password that (improperly) contains a raw `@` is fully redacted rather
/// than partially leaked. URLs without `://` (e.g. scp-style `git@host:path`)
/// are returned unchanged.
fn redact_url(url: &str) -> String {
    let scheme_end = match url.find("://") {
        Some(pos) => pos,
        None => return url.to_string(),
    };
    let authority_start = scheme_end + 3;
    let rest = &url[authority_start..];

    // The authority ends at the first path, query or fragment delimiter.
    let authority_len = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());

    match rest[..authority_len].rfind('@') {
        Some(at) => format!(
            "{}[REDACTED]@{}",
            &url[..authority_start],
            &rest[at + 1..]
        ),
        None => url.to_string(),
    }
}
34
35fn capture_head_oid(repo: &git2::Repository) -> Result<[u8; 20], MemoryError> {
39 match repo.head() {
40 Ok(h) => {
41 let oid = h.peel_to_commit()?.id();
42 let mut buf = [0u8; 20];
43 buf.copy_from_slice(oid.as_bytes());
44 Ok(buf)
45 }
46 Err(e) if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound => {
48 Ok([0u8; 20])
49 }
50 Err(e) => Err(MemoryError::Git(e)),
51 }
52}
53
54fn fast_forward(
59 repo: &git2::Repository,
60 fetch_commit: &git2::AnnotatedCommit,
61 branch: &str,
62) -> Result<PullResult, MemoryError> {
63 let old_head = capture_head_oid(repo)?;
64
65 let refname = format!("refs/heads/{branch}");
66 let target_oid = fetch_commit.id();
67
68 match repo.find_reference(&refname) {
69 Ok(mut reference) => {
70 reference.set_target(target_oid, &format!("pull: fast-forward to {}", target_oid))?;
71 }
72 Err(e) if e.code() == ErrorCode::NotFound => {
73 repo.reference(
75 &refname,
76 target_oid,
77 true,
78 &format!("pull: create branch {} from fetch", branch),
79 )?;
80 }
81 Err(e) => return Err(MemoryError::Git(e)),
82 }
83
84 repo.set_head(&refname)?;
85 let mut checkout = CheckoutBuilder::default();
86 checkout.force();
87 repo.checkout_head(Some(&mut checkout))?;
88
89 let mut new_head = [0u8; 20];
90 new_head.copy_from_slice(target_oid.as_bytes());
91
92 info!("pull: fast-forwarded to {}", target_oid);
93 Ok(PullResult::FastForward { old_head, new_head })
94}
95
96fn build_auth_callbacks(token: SecretString) -> git2::RemoteCallbacks<'static> {
100 let mut callbacks = git2::RemoteCallbacks::new();
101 callbacks.credentials(move |_url, _username, _allowed| {
102 git2::Cred::userpass_plaintext("x-access-token", token.expose_secret())
103 });
104 callbacks
105}
106
107pub struct MemoryRepo {
109 inner: Mutex<Repository>,
110 root: PathBuf,
111}
112
113unsafe impl Send for MemoryRepo {}
117unsafe impl Sync for MemoryRepo {}
118
119impl MemoryRepo {
120 pub fn init_or_open(path: &Path, remote_url: Option<&str>) -> Result<Self, MemoryError> {
125 let repo = if path.join(".git").exists() {
126 Repository::open(path)?
127 } else {
128 let mut opts = git2::RepositoryInitOptions::new();
129 opts.initial_head("main");
130 let repo = Repository::init_opts(path, &opts)?;
131 let gitignore = path.join(".gitignore");
133 if !gitignore.exists() {
134 std::fs::write(&gitignore, ".memory-mcp-index/\n")?;
135 }
136 {
138 let mut index = repo.index()?;
139 index.add_path(Path::new(".gitignore"))?;
140 index.write()?;
141 let tree_oid = index.write_tree()?;
142 let tree = repo.find_tree(tree_oid)?;
143 let sig = Signature::now("memory-mcp", "memory-mcp@local")?;
144 repo.commit(
145 Some("HEAD"),
146 &sig,
147 &sig,
148 "chore: init repository",
149 &tree,
150 &[],
151 )?;
152 }
153 repo
154 };
155
156 if let Some(url) = remote_url {
158 match repo.find_remote("origin") {
159 Ok(existing) => {
160 let current_url = existing.url().unwrap_or("");
162 if current_url != url {
163 repo.remote_set_url("origin", url)?;
164 info!("updated origin remote URL to {}", redact_url(url));
165 }
166 }
167 Err(e) if e.code() == ErrorCode::NotFound => {
168 repo.remote("origin", url)?;
169 info!("created origin remote pointing at {}", redact_url(url));
170 }
171 Err(e) => return Err(MemoryError::Git(e)),
172 }
173 }
174
175 Ok(Self {
176 inner: Mutex::new(repo),
177 root: path.to_path_buf(),
178 })
179 }
180
181 fn memory_path(&self, name: &str, scope: &Scope) -> PathBuf {
183 self.root
184 .join(scope.dir_prefix())
185 .join(format!("{}.md", name))
186 }
187
188 pub async fn save_memory(self: &Arc<Self>, memory: &Memory) -> Result<(), MemoryError> {
193 validate_name(&memory.name)?;
194 if let Scope::Project(ref project_name) = memory.metadata.scope {
195 validate_name(project_name)?;
196 }
197
198 let file_path = self.memory_path(&memory.name, &memory.metadata.scope);
199 self.assert_within_root(&file_path)?;
200
201 let arc = Arc::clone(self);
202 let memory = memory.clone();
203 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
204 let repo = arc
205 .inner
206 .lock()
207 .expect("lock poisoned — prior panic corrupted state");
208
209 if let Some(parent) = file_path.parent() {
211 std::fs::create_dir_all(parent)?;
212 }
213
214 let markdown = memory.to_markdown()?;
215 arc.write_memory_file(&file_path, markdown.as_bytes())?;
216
217 arc.git_add_and_commit(
218 &repo,
219 &file_path,
220 &format!("chore: save memory '{}'", memory.name),
221 )?;
222 Ok(())
223 })
224 .await
225 .map_err(|e| MemoryError::Join(e.to_string()))?
226 }
227
228 pub async fn delete_memory(
230 self: &Arc<Self>,
231 name: &str,
232 scope: &Scope,
233 ) -> Result<(), MemoryError> {
234 validate_name(name)?;
235 if let Scope::Project(ref project_name) = *scope {
236 validate_name(project_name)?;
237 }
238
239 let file_path = self.memory_path(name, scope);
240 self.assert_within_root(&file_path)?;
241
242 let arc = Arc::clone(self);
243 let name = name.to_string();
244 let file_path_clone = file_path.clone();
245 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
246 let repo = arc
247 .inner
248 .lock()
249 .expect("lock poisoned — prior panic corrupted state");
250
251 match std::fs::symlink_metadata(&file_path_clone) {
253 Err(_) => return Err(MemoryError::NotFound { name: name.clone() }),
254 Ok(m) if m.file_type().is_symlink() => {
255 return Err(MemoryError::InvalidInput {
256 reason: format!(
257 "path '{}' is a symlink, which is not permitted",
258 file_path_clone.display()
259 ),
260 });
261 }
262 Ok(_) => {}
263 }
264
265 std::fs::remove_file(&file_path_clone)?;
266 let relative =
268 file_path_clone
269 .strip_prefix(&arc.root)
270 .map_err(|e| MemoryError::InvalidInput {
271 reason: format!("path strip error: {}", e),
272 })?;
273 let mut index = repo.index()?;
274 index.remove_path(relative)?;
275 index.write()?;
276
277 let tree_oid = index.write_tree()?;
278 let tree = repo.find_tree(tree_oid)?;
279 let sig = arc.signature(&repo)?;
280 let message = format!("chore: delete memory '{}'", name);
281
282 match repo.head() {
283 Ok(head) => {
284 let parent_commit = head.peel_to_commit()?;
285 repo.commit(Some("HEAD"), &sig, &sig, &message, &tree, &[&parent_commit])?;
286 }
287 Err(e)
288 if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound =>
289 {
290 repo.commit(Some("HEAD"), &sig, &sig, &message, &tree, &[])?;
291 }
292 Err(e) => return Err(MemoryError::Git(e)),
293 }
294
295 Ok(())
296 })
297 .await
298 .map_err(|e| MemoryError::Join(e.to_string()))?
299 }
300
301 pub async fn read_memory(
303 self: &Arc<Self>,
304 name: &str,
305 scope: &Scope,
306 ) -> Result<Memory, MemoryError> {
307 validate_name(name)?;
308 if let Scope::Project(ref project_name) = *scope {
309 validate_name(project_name)?;
310 }
311
312 let file_path = self.memory_path(name, scope);
313 self.assert_within_root(&file_path)?;
314
315 let arc = Arc::clone(self);
316 let name = name.to_string();
317 tokio::task::spawn_blocking(move || -> Result<Memory, MemoryError> {
318 match std::fs::symlink_metadata(&file_path) {
320 Err(_) => return Err(MemoryError::NotFound { name }),
321 Ok(m) if m.file_type().is_symlink() => {
322 return Err(MemoryError::InvalidInput {
323 reason: format!(
324 "path '{}' is a symlink, which is not permitted",
325 file_path.display()
326 ),
327 });
328 }
329 Ok(_) => {}
330 }
331 let raw = arc.read_memory_file(&file_path)?;
332 Memory::from_markdown(&raw)
333 })
334 .await
335 .map_err(|e| MemoryError::Join(e.to_string()))?
336 }
337
338 pub async fn list_memories(
340 self: &Arc<Self>,
341 scope: Option<&Scope>,
342 ) -> Result<Vec<Memory>, MemoryError> {
343 let root = self.root.clone();
344 let scope_clone = scope.cloned();
345
346 tokio::task::spawn_blocking(move || -> Result<Vec<Memory>, MemoryError> {
347 let dirs: Vec<PathBuf> = match scope_clone.as_ref() {
348 Some(s) => vec![root.join(s.dir_prefix())],
349 None => {
350 let mut dirs = Vec::new();
352 let global = root.join("global");
353 if global.exists() {
354 dirs.push(global);
355 }
356 let projects = root.join("projects");
357 if projects.exists() {
358 for entry in std::fs::read_dir(&projects)? {
359 let entry = entry?;
360 if entry.file_type()?.is_dir() {
361 dirs.push(entry.path());
362 }
363 }
364 }
365 dirs
366 }
367 };
368
369 fn collect_md_files(dir: &Path, out: &mut Vec<Memory>) -> Result<(), MemoryError> {
370 if !dir.exists() {
371 return Ok(());
372 }
373 for entry in std::fs::read_dir(dir)? {
374 let entry = entry?;
375 let path = entry.path();
376 let ft = entry.file_type()?;
377 if ft.is_symlink() {
379 warn!(
380 "skipping symlink at {:?} — symlinks are not permitted in the memory store",
381 path
382 );
383 continue;
384 }
385 if ft.is_dir() {
386 collect_md_files(&path, out)?;
387 } else if path.extension().and_then(|e| e.to_str()) == Some("md") {
388 let raw = std::fs::read_to_string(&path)?;
389 match Memory::from_markdown(&raw) {
390 Ok(m) => out.push(m),
391 Err(e) => {
392 warn!("skipping {:?}: {}", path, e);
393 }
394 }
395 }
396 }
397 Ok(())
398 }
399
400 let mut memories = Vec::new();
401 for dir in dirs {
402 collect_md_files(&dir, &mut memories)?;
403 }
404
405 Ok(memories)
406 })
407 .await
408 .map_err(|e| MemoryError::Join(e.to_string()))?
409 }
410
411 pub async fn push(
416 self: &Arc<Self>,
417 auth: &AuthProvider,
418 branch: &str,
419 ) -> Result<(), MemoryError> {
420 let token_result = auth.resolve_token();
424 let arc = Arc::clone(self);
425 let branch = branch.to_string();
426
427 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
428 let repo = arc
429 .inner
430 .lock()
431 .expect("lock poisoned — prior panic corrupted state");
432
433 let mut remote = match repo.find_remote("origin") {
434 Ok(r) => r,
435 Err(e) if e.code() == ErrorCode::NotFound => {
436 warn!("push: no origin remote configured — skipping (local-only mode)");
437 return Ok(());
438 }
439 Err(e) => return Err(MemoryError::Git(e)),
440 };
441
442 let token = token_result?;
444 let mut callbacks = build_auth_callbacks(token);
445
446 let rejections: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
449 let rej = Arc::clone(&rejections);
450 callbacks.push_update_reference(move |refname, status| {
451 if let Some(msg) = status {
452 rej.lock()
453 .expect("rejection lock poisoned")
454 .push(format!("{refname}: {msg}"));
455 }
456 Ok(())
457 });
458
459 let mut push_opts = git2::PushOptions::new();
460 push_opts.remote_callbacks(callbacks);
461
462 let refspec = format!("refs/heads/{branch}:refs/heads/{branch}");
463 if let Err(e) = remote.push(&[&refspec], Some(&mut push_opts)) {
464 warn!("push to origin failed at transport level: {e}");
465 return Err(MemoryError::Git(e));
466 }
467
468 let rejected = rejections.lock().expect("rejection lock poisoned");
469 if !rejected.is_empty() {
470 return Err(MemoryError::PushRejected(rejected.join("; ")));
471 }
472
473 info!("pushed branch '{}' to origin", branch);
474 Ok(())
475 })
476 .await
477 .map_err(|e| MemoryError::Join(e.to_string()))?
478 }
479
480 fn merge_with_remote(
485 &self,
486 repo: &git2::Repository,
487 fetch_commit: &git2::AnnotatedCommit,
488 branch: &str,
489 ) -> Result<PullResult, MemoryError> {
490 let oid = repo.head()?.peel_to_commit()?.id();
494 let mut old_head = [0u8; 20];
495 old_head.copy_from_slice(oid.as_bytes());
496
497 let mut merge_opts = MergeOptions::new();
498 merge_opts.fail_on_conflict(false);
499 repo.merge(&[fetch_commit], Some(&mut merge_opts), None)?;
500
501 let mut index = repo.index()?;
502 let conflicts_resolved = if index.has_conflicts() {
503 self.resolve_conflicts_by_recency(repo, &mut index)?
504 } else {
505 0
506 };
507
508 if index.has_conflicts() {
512 let _ = repo.cleanup_state();
513 return Err(MemoryError::Internal(
514 "unresolved conflicts remain after auto-resolution".into(),
515 ));
516 }
517
518 index.write()?;
520 let tree_oid = index.write_tree()?;
521 let tree = repo.find_tree(tree_oid)?;
522 let sig = self.signature(repo)?;
523
524 let head_commit = repo.head()?.peel_to_commit()?;
525 let fetch_commit_obj = repo.find_commit(fetch_commit.id())?;
526
527 let new_commit_oid = repo.commit(
528 Some("HEAD"),
529 &sig,
530 &sig,
531 &format!("chore: merge origin/{}", branch),
532 &tree,
533 &[&head_commit, &fetch_commit_obj],
534 )?;
535
536 repo.cleanup_state()?;
537
538 let mut new_head = [0u8; 20];
539 new_head.copy_from_slice(new_commit_oid.as_bytes());
540
541 info!(
542 "pull: merge complete ({} conflicts auto-resolved)",
543 conflicts_resolved
544 );
545 Ok(PullResult::Merged {
546 conflicts_resolved,
547 old_head,
548 new_head,
549 })
550 }
551
552 pub async fn pull(
558 self: &Arc<Self>,
559 auth: &AuthProvider,
560 branch: &str,
561 ) -> Result<PullResult, MemoryError> {
562 let token_result = auth.resolve_token();
566 let arc = Arc::clone(self);
567 let branch = branch.to_string();
568
569 tokio::task::spawn_blocking(move || -> Result<PullResult, MemoryError> {
570 let repo = arc
571 .inner
572 .lock()
573 .expect("lock poisoned — prior panic corrupted state");
574
575 let mut remote = match repo.find_remote("origin") {
577 Ok(r) => r,
578 Err(e) if e.code() == ErrorCode::NotFound => {
579 warn!("pull: no origin remote configured — skipping (local-only mode)");
580 return Ok(PullResult::NoRemote);
581 }
582 Err(e) => return Err(MemoryError::Git(e)),
583 };
584
585 let token = token_result?;
587
588 let callbacks = build_auth_callbacks(token);
590 let mut fetch_opts = git2::FetchOptions::new();
591 fetch_opts.remote_callbacks(callbacks);
592 remote.fetch(&[&branch], Some(&mut fetch_opts), None)?;
593
594 let fetch_head = match repo.find_reference("FETCH_HEAD") {
596 Ok(r) => r,
597 Err(e) if e.code() == ErrorCode::NotFound => {
598 return Ok(PullResult::UpToDate);
600 }
601 Err(e)
602 if e.class() == git2::ErrorClass::Reference
603 && e.message().contains("corrupted") =>
604 {
605 info!("pull: FETCH_HEAD is empty or corrupted — treating as empty remote");
607 return Ok(PullResult::UpToDate);
608 }
609 Err(e) => return Err(MemoryError::Git(e)),
610 };
611 let fetch_commit = match repo.reference_to_annotated_commit(&fetch_head) {
612 Ok(c) => c,
613 Err(e) if e.class() == git2::ErrorClass::Reference => {
614 info!("pull: FETCH_HEAD not resolvable — treating as empty remote");
616 return Ok(PullResult::UpToDate);
617 }
618 Err(e) => return Err(MemoryError::Git(e)),
619 };
620
621 let (analysis, _preference) = repo.merge_analysis(&[&fetch_commit])?;
623
624 if analysis.is_up_to_date() {
625 info!("pull: already up to date");
626 return Ok(PullResult::UpToDate);
627 }
628
629 if analysis.is_fast_forward() {
630 return fast_forward(&repo, &fetch_commit, &branch);
631 }
632
633 arc.merge_with_remote(&repo, &fetch_commit, &branch)
634 })
635 .await
636 .map_err(|e| MemoryError::Join(e.to_string()))?
637 }
638
639 pub fn diff_changed_memories(
647 &self,
648 old_oid: [u8; 20],
649 new_oid: [u8; 20],
650 ) -> Result<ChangedMemories, MemoryError> {
651 let repo = self
652 .inner
653 .lock()
654 .expect("lock poisoned — prior panic corrupted state");
655
656 let new_git_oid = git2::Oid::from_bytes(&new_oid).map_err(MemoryError::Git)?;
657 let new_tree = repo.find_commit(new_git_oid)?.tree()?;
658
659 let diff = if old_oid == [0u8; 20] {
662 repo.diff_tree_to_tree(None, Some(&new_tree), None)?
663 } else {
664 let old_git_oid = git2::Oid::from_bytes(&old_oid).map_err(MemoryError::Git)?;
665 let old_tree = repo.find_commit(old_git_oid)?.tree()?;
666 repo.diff_tree_to_tree(Some(&old_tree), Some(&new_tree), None)?
667 };
668
669 let mut changes = ChangedMemories::default();
670
671 diff.foreach(
672 &mut |delta, _progress| {
673 use git2::Delta;
674
675 let path = match delta.new_file().path().or_else(|| delta.old_file().path()) {
676 Some(p) => p,
677 None => return true,
678 };
679
680 let path_str = match path.to_str() {
681 Some(s) => s,
682 None => return true,
683 };
684
685 if !path_str.ends_with(".md") {
687 return true;
688 }
689 if !path_str.starts_with("global/") && !path_str.starts_with("projects/") {
690 return true;
691 }
692
693 let qualified = &path_str[..path_str.len() - 3];
695
696 match delta.status() {
697 Delta::Added | Delta::Modified => {
698 changes.upserted.push(qualified.to_string());
699 }
700 Delta::Renamed | Delta::Copied => {
701 if matches!(delta.status(), Delta::Renamed) {
704 if let Some(old_path) = delta.old_file().path().and_then(|p| p.to_str())
705 {
706 if old_path.ends_with(".md")
707 && (old_path.starts_with("global/")
708 || old_path.starts_with("projects/"))
709 {
710 changes
711 .removed
712 .push(old_path[..old_path.len() - 3].to_string());
713 }
714 }
715 }
716 changes.upserted.push(qualified.to_string());
717 }
718 Delta::Deleted => {
719 changes.removed.push(qualified.to_string());
720 }
721 _ => {}
722 }
723
724 true
725 },
726 None,
727 None,
728 None,
729 )
730 .map_err(MemoryError::Git)?;
731
732 Ok(changes)
733 }
734
    /// Auto-resolves conflicted entries in `index` by keeping one side's whole
    /// file content: last writer wins on the memory's `metadata.updated_at`.
    ///
    /// Each side's blob is loaded. When it is UTF-8 and parses as a `Memory`,
    /// its `updated_at` is used for comparison:
    /// - both sides UTF-8: "theirs" wins only if its timestamp is strictly
    ///   newer than ours. Ties, or a missing/unparseable timestamp on either
    ///   side, keep "ours".
    /// - only one side UTF-8: that side is kept.
    /// - neither side UTF-8: the raw "ours" blob is kept if loaded, else "theirs".
    /// - no blob loadable on either side: the conflict entry is removed from
    ///   the index and no file is written.
    ///
    /// Conflicts with neither an "ours" nor a "theirs" entry, or with a
    /// non-UTF-8 path, are skipped and stay in the index.
    ///
    /// The chosen bytes are written to the working tree and staged with
    /// `index.add_path`. Writing happens only after `assert_within_root`,
    /// since these paths come from merged history rather than local input.
    /// This function does not call `index.write()`.
    ///
    /// Returns the number of conflict entries resolved or removed.
    fn resolve_conflicts_by_recency(
        &self,
        repo: &Repository,
        index: &mut git2::Index,
    ) -> Result<usize, MemoryError> {
        // Owned snapshot of one conflict. Collected up front (inside the block
        // below) so the borrow of `index` held by `index.conflicts()` ends
        // before `index` is mutated.
        struct ConflictInfo {
            path: PathBuf,
            our_blob: Option<Vec<u8>>,
            their_blob: Option<Vec<u8>>,
        }

        let mut conflicts_info: Vec<ConflictInfo> = Vec::new();

        {
            let conflicts = index.conflicts()?;
            for conflict in conflicts {
                let conflict = conflict?;

                // Absolute working-tree path, taken from "ours" or else "theirs".
                let path = conflict
                    .our
                    .as_ref()
                    .or(conflict.their.as_ref())
                    .and_then(|e| std::str::from_utf8(&e.path).ok())
                    .map(|s| self.root.join(s));

                let path = match path {
                    Some(p) => p,
                    None => continue,
                };

                // Blob contents per side; `None` if the side has no entry or
                // the blob cannot be found.
                let our_blob = conflict
                    .our
                    .as_ref()
                    .and_then(|e| repo.find_blob(e.id).ok())
                    .map(|b| b.content().to_vec());

                let their_blob = conflict
                    .their
                    .as_ref()
                    .and_then(|e| repo.find_blob(e.id).ok())
                    .map(|b| b.content().to_vec());

                conflicts_info.push(ConflictInfo {
                    path,
                    our_blob,
                    their_blob,
                });
            }
        }

        let mut resolved = 0usize;

        for info in conflicts_info {
            // UTF-8 views of each side (None when missing or not valid UTF-8).
            let our_str = info
                .our_blob
                .as_deref()
                .and_then(|b| std::str::from_utf8(b).ok())
                .map(str::to_owned);
            let their_str = info
                .their_blob
                .as_deref()
                .and_then(|b| std::str::from_utf8(b).ok())
                .map(str::to_owned);

            // `updated_at` of each side, when it parses as a Memory.
            let our_ts = our_str
                .as_deref()
                .and_then(|s| Memory::from_markdown(s).ok())
                .map(|m| m.metadata.updated_at);
            let their_ts = their_str
                .as_deref()
                .and_then(|s| Memory::from_markdown(s).ok())
                .map(|m| m.metadata.updated_at);

            let (chosen_bytes, label): (Vec<u8>, String) =
                match (our_str.as_deref(), their_str.as_deref()) {
                    (Some(ours), Some(theirs)) => match (our_ts, their_ts) {
                        // Strictly newer remote edit wins...
                        (Some(ot), Some(tt)) if tt > ot => (
                            theirs.as_bytes().to_vec(),
                            format!("theirs (updated_at: {})", tt),
                        ),
                        // ...otherwise (tie or unparseable theirs) keep ours.
                        (Some(ot), _) => (
                            ours.as_bytes().to_vec(),
                            format!("ours (updated_at: {})", ot),
                        ),
                        _ => (
                            ours.as_bytes().to_vec(),
                            "ours (timestamp unparseable)".to_string(),
                        ),
                    },
                    (Some(ours), None) => (
                        ours.as_bytes().to_vec(),
                        "ours (theirs missing)".to_string(),
                    ),
                    (None, Some(theirs)) => (
                        theirs.as_bytes().to_vec(),
                        "theirs (ours missing)".to_string(),
                    ),
                    (None, None) => {
                        // Neither side is UTF-8: fall back to raw blob bytes.
                        match (info.our_blob.as_deref(), info.their_blob.as_deref()) {
                            (Some(ours), _) => {
                                (ours.to_vec(), "ours (binary/non-UTF-8)".to_string())
                            }
                            (_, Some(theirs)) => {
                                (theirs.to_vec(), "theirs (binary/non-UTF-8)".to_string())
                            }
                            (None, None) => {
                                // Nothing loadable on either side: drop the
                                // conflict entry without touching the file.
                                warn!(
                                    "conflict at '{}': both sides missing — removing from index",
                                    info.path.display()
                                );
                                let relative = info.path.strip_prefix(&self.root).map_err(|e| {
                                    MemoryError::InvalidInput {
                                        reason: format!(
                                            "path strip error during conflict resolution: {}",
                                            e
                                        ),
                                    }
                                })?;
                                index.conflict_remove(relative)?;
                                resolved += 1;
                                continue;
                            }
                        }
                    }
                };

            warn!(
                "conflict resolved: {} — kept {}",
                info.path.display(),
                label
            );

            // The path came from merged (remote) history: verify containment
            // before writing anything.
            self.assert_within_root(&info.path)?;
            if let Some(parent) = info.path.parent() {
                std::fs::create_dir_all(parent)?;
            }
            self.write_memory_file(&info.path, &chosen_bytes)?;

            let relative =
                info.path
                    .strip_prefix(&self.root)
                    .map_err(|e| MemoryError::InvalidInput {
                        reason: format!("path strip error during conflict resolution: {}", e),
                    })?;
            // Staging the written file replaces the conflict entries for this path.
            index.add_path(relative)?;

            resolved += 1;
        }

        Ok(resolved)
    }
904
905 fn signature<'r>(&self, repo: &'r Repository) -> Result<Signature<'r>, MemoryError> {
906 let sig = repo
908 .signature()
909 .or_else(|_| Signature::now("memory-mcp", "memory-mcp@local"))?;
910 Ok(sig)
911 }
912
913 fn git_add_and_commit(
915 &self,
916 repo: &Repository,
917 file_path: &Path,
918 message: &str,
919 ) -> Result<(), MemoryError> {
920 let relative =
921 file_path
922 .strip_prefix(&self.root)
923 .map_err(|e| MemoryError::InvalidInput {
924 reason: format!("path strip error: {}", e),
925 })?;
926
927 let mut index = repo.index()?;
928 index.add_path(relative)?;
929 index.write()?;
930
931 let tree_oid = index.write_tree()?;
932 let tree = repo.find_tree(tree_oid)?;
933 let sig = self.signature(repo)?;
934
935 match repo.head() {
936 Ok(head) => {
937 let parent_commit = head.peel_to_commit()?;
938 repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &[&parent_commit])?;
939 }
940 Err(e) if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound => {
941 repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &[])?;
943 }
944 Err(e) => return Err(MemoryError::Git(e)),
945 }
946
947 Ok(())
948 }
949
950 fn assert_within_root(&self, path: &Path) -> Result<(), MemoryError> {
953 let parent = path.parent().unwrap_or(path);
956 let filename = path.file_name().ok_or_else(|| MemoryError::InvalidInput {
957 reason: "path has no filename component".to_string(),
958 })?;
959
960 let canon_parent = {
963 let mut p = parent.to_path_buf();
964 let mut suffixes: Vec<std::ffi::OsString> = Vec::new();
965 loop {
966 match p.canonicalize() {
967 Ok(c) => {
968 let mut full = c;
969 for s in suffixes.into_iter().rev() {
970 full.push(s);
971 }
972 break full;
973 }
974 Err(_) => {
975 if let Some(name) = p.file_name() {
976 suffixes.push(name.to_os_string());
977 }
978 match p.parent() {
979 Some(par) => p = par.to_path_buf(),
980 None => {
981 return Err(MemoryError::InvalidInput {
982 reason: "cannot resolve any ancestor of path".into(),
983 });
984 }
985 }
986 }
987 }
988 }
989 };
990
991 let resolved = canon_parent.join(filename);
992
993 let canon_root = self
994 .root
995 .canonicalize()
996 .map_err(|e| MemoryError::InvalidInput {
997 reason: format!("cannot canonicalize repo root: {}", e),
998 })?;
999
1000 if !resolved.starts_with(&canon_root) {
1001 return Err(MemoryError::InvalidInput {
1002 reason: format!(
1003 "path '{}' escapes repository root '{}'",
1004 resolved.display(),
1005 canon_root.display()
1006 ),
1007 });
1008 }
1009
1010 {
1015 let mut probe = canon_root.clone();
1016 let relative =
1018 resolved
1019 .strip_prefix(&canon_root)
1020 .map_err(|e| MemoryError::InvalidInput {
1021 reason: format!("path strip error: {}", e),
1022 })?;
1023 for component in relative.components() {
1024 probe.push(component);
1025 if (probe.exists() || probe.symlink_metadata().is_ok())
1027 && probe
1028 .symlink_metadata()
1029 .map(|m| m.file_type().is_symlink())
1030 .unwrap_or(false)
1031 {
1032 return Err(MemoryError::InvalidInput {
1033 reason: format!(
1034 "path component '{}' is a symlink, which is not allowed",
1035 probe.display()
1036 ),
1037 });
1038 }
1039 }
1040 }
1041
1042 Ok(())
1043 }
1044
1045 fn write_memory_file(&self, path: &Path, data: &[u8]) -> Result<(), MemoryError> {
1056 if path
1058 .symlink_metadata()
1059 .map(|m| m.file_type().is_symlink())
1060 .unwrap_or(false)
1061 {
1062 return Err(MemoryError::InvalidInput {
1063 reason: format!("refusing to write through symlink: {}", path.display()),
1064 });
1065 }
1066
1067 #[cfg(unix)]
1071 {
1072 use std::os::unix::fs::OpenOptionsExt as _;
1073 if let Err(e) = std::fs::OpenOptions::new()
1074 .read(true)
1075 .custom_flags(libc::O_NOFOLLOW)
1076 .open(path)
1077 {
1078 if e.kind() != std::io::ErrorKind::NotFound {
1080 return Err(MemoryError::InvalidInput {
1081 reason: format!("O_NOFOLLOW check failed for {}: {e}", path.display()),
1082 });
1083 }
1084 }
1085 }
1086
1087 crate::fs_util::atomic_write(path, data)?;
1088 Ok(())
1089 }
1090
1091 fn read_memory_file(&self, path: &Path) -> Result<String, MemoryError> {
1096 #[cfg(unix)]
1097 {
1098 use std::io::Read as _;
1099 use std::os::unix::fs::OpenOptionsExt as _;
1100 let mut f = std::fs::OpenOptions::new()
1101 .read(true)
1102 .custom_flags(libc::O_NOFOLLOW)
1103 .open(path)?;
1104 let mut buf = String::new();
1105 f.read_to_string(&mut buf)?;
1106 Ok(buf)
1107 }
1108 #[cfg(not(unix))]
1109 {
1110 Ok(std::fs::read_to_string(path)?)
1111 }
1112 }
1113}
1114
1115#[cfg(test)]
1120mod tests {
1121 use super::*;
1122 use crate::auth::AuthProvider;
1123 use crate::types::{Memory, MemoryMetadata, PullResult, Scope};
1124 use std::sync::Arc;
1125
1126 fn test_auth() -> AuthProvider {
1127 AuthProvider::with_token("test-token-unused-for-file-remotes")
1128 }
1129
1130 fn make_memory(name: &str, content: &str, updated_at_secs: i64) -> Memory {
1131 let meta = MemoryMetadata {
1132 tags: vec![],
1133 scope: Scope::Global,
1134 created_at: chrono::DateTime::from_timestamp(1_700_000_000, 0).unwrap(),
1135 updated_at: chrono::DateTime::from_timestamp(updated_at_secs, 0).unwrap(),
1136 source: None,
1137 };
1138 Memory::new(name.to_string(), content.to_string(), meta)
1139 }
1140
1141 fn setup_bare_remote() -> (tempfile::TempDir, String) {
1142 let dir = tempfile::tempdir().expect("failed to create temp dir");
1143 git2::Repository::init_bare(dir.path()).expect("failed to init bare repo");
1144 let url = format!("file://{}", dir.path().display());
1145 (dir, url)
1146 }
1147
1148 fn open_repo(dir: &tempfile::TempDir, remote_url: Option<&str>) -> Arc<MemoryRepo> {
1149 Arc::new(MemoryRepo::init_or_open(dir.path(), remote_url).expect("failed to init repo"))
1150 }
1151
1152 #[test]
1155 fn redact_url_strips_userinfo() {
1156 assert_eq!(
1157 redact_url("https://user:ghp_token123@github.com/org/repo.git"),
1158 "https://[REDACTED]@github.com/org/repo.git"
1159 );
1160 }
1161
1162 #[test]
1163 fn redact_url_no_at_passthrough() {
1164 let url = "https://github.com/org/repo.git";
1165 assert_eq!(redact_url(url), url);
1166 }
1167
1168 #[test]
1169 fn redact_url_file_protocol_passthrough() {
1170 let url = "file:///tmp/bare.git";
1171 assert_eq!(redact_url(url), url);
1172 }
1173
1174 #[test]
1177 fn assert_within_root_accepts_valid_path() {
1178 let dir = tempfile::tempdir().unwrap();
1179 let repo = MemoryRepo::init_or_open(dir.path(), None).unwrap();
1180 let valid = dir.path().join("global").join("my-memory.md");
1181 std::fs::create_dir_all(valid.parent().unwrap()).unwrap();
1183 assert!(repo.assert_within_root(&valid).is_ok());
1184 }
1185
1186 #[test]
1187 fn assert_within_root_rejects_escape() {
1188 let dir = tempfile::tempdir().unwrap();
1189 let repo = MemoryRepo::init_or_open(dir.path(), None).unwrap();
1190 let _evil = dir
1193 .path()
1194 .join("..")
1195 .join("..")
1196 .join("..")
1197 .join("tmp")
1198 .join("evil.md");
1199 let outside = std::path::PathBuf::from("/tmp/definitely-outside");
1203 assert!(repo.assert_within_root(&outside).is_err());
1204 }
1205
1206 #[tokio::test]
1209 async fn push_local_only_returns_ok() {
1210 let dir = tempfile::tempdir().unwrap();
1211 let repo = open_repo(&dir, None);
1212 let auth = test_auth();
1213 let result = repo.push(&auth, "main").await;
1215 assert!(result.is_ok());
1216 }
1217
1218 #[tokio::test]
1219 async fn pull_local_only_returns_no_remote() {
1220 let dir = tempfile::tempdir().unwrap();
1221 let repo = open_repo(&dir, None);
1222 let auth = test_auth();
1223 let result = repo.pull(&auth, "main").await.unwrap();
1224 assert!(matches!(result, PullResult::NoRemote));
1225 }
1226
1227 #[tokio::test]
1230 async fn push_to_bare_remote() {
1231 let (_remote_dir, remote_url) = setup_bare_remote();
1232 let local_dir = tempfile::tempdir().unwrap();
1233 let repo = open_repo(&local_dir, Some(&remote_url));
1234 let auth = test_auth();
1235
1236 let mem = make_memory("test-push", "push content", 1_700_000_000);
1238 repo.save_memory(&mem).await.unwrap();
1239
1240 repo.push(&auth, "main").await.unwrap();
1242
1243 let bare = git2::Repository::open_bare(_remote_dir.path()).unwrap();
1245 let head = bare.find_reference("refs/heads/main").unwrap();
1246 let commit = head.peel_to_commit().unwrap();
1247 assert!(commit.message().unwrap().contains("test-push"));
1248 }
1249
1250 #[tokio::test]
1251 async fn pull_from_empty_bare_remote_returns_up_to_date() {
1252 let (_remote_dir, remote_url) = setup_bare_remote();
1253 let local_dir = tempfile::tempdir().unwrap();
1254 let repo = open_repo(&local_dir, Some(&remote_url));
1255 let auth = test_auth();
1256
1257 let mem = make_memory("seed", "seed content", 1_700_000_000);
1259 repo.save_memory(&mem).await.unwrap();
1260
1261 let result = repo.pull(&auth, "main").await.unwrap();
1263 assert!(matches!(result, PullResult::UpToDate));
1264 }
1265
1266 #[tokio::test]
1267 async fn pull_fast_forward() {
1268 let (_remote_dir, remote_url) = setup_bare_remote();
1269 let auth = test_auth();
1270
1271 let dir_a = tempfile::tempdir().unwrap();
1273 let repo_a = open_repo(&dir_a, Some(&remote_url));
1274 let mem = make_memory("from-a", "content from A", 1_700_000_000);
1275 repo_a.save_memory(&mem).await.unwrap();
1276 repo_a.push(&auth, "main").await.unwrap();
1277
1278 let dir_b = tempfile::tempdir().unwrap();
1280 let repo_b = open_repo(&dir_b, Some(&remote_url));
1281 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1283 repo_b.save_memory(&seed).await.unwrap();
1284
1285 let result = repo_b.pull(&auth, "main").await.unwrap();
1286 assert!(
1287 matches!(
1288 result,
1289 PullResult::FastForward { .. } | PullResult::Merged { .. }
1290 ),
1291 "expected fast-forward or merge, got {:?}",
1292 result
1293 );
1294
1295 let file = dir_b.path().join("global").join("from-a.md");
1297 assert!(file.exists(), "from-a.md should exist in repo B after pull");
1298 }
1299
1300 #[tokio::test]
1301 async fn pull_up_to_date_after_push() {
1302 let (_remote_dir, remote_url) = setup_bare_remote();
1303 let local_dir = tempfile::tempdir().unwrap();
1304 let repo = open_repo(&local_dir, Some(&remote_url));
1305 let auth = test_auth();
1306
1307 let mem = make_memory("synced", "synced content", 1_700_000_000);
1308 repo.save_memory(&mem).await.unwrap();
1309 repo.push(&auth, "main").await.unwrap();
1310
1311 let result = repo.pull(&auth, "main").await.unwrap();
1313 assert!(matches!(result, PullResult::UpToDate));
1314 }
1315
1316 #[tokio::test]
1319 async fn pull_merge_conflict_theirs_newer_wins() {
1320 let (_remote_dir, remote_url) = setup_bare_remote();
1321 let auth = test_auth();
1322
1323 let dir_a = tempfile::tempdir().unwrap();
1325 let repo_a = open_repo(&dir_a, Some(&remote_url));
1326 let mem_a1 = make_memory("shared", "version from A initial", 1_700_000_100);
1327 repo_a.save_memory(&mem_a1).await.unwrap();
1328 repo_a.push(&auth, "main").await.unwrap();
1329
1330 let dir_b = tempfile::tempdir().unwrap();
1332 let repo_b = open_repo(&dir_b, Some(&remote_url));
1333 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1334 repo_b.save_memory(&seed).await.unwrap();
1335 repo_b.pull(&auth, "main").await.unwrap();
1336
1337 let mem_b = make_memory("shared", "version from B (newer)", 1_700_000_300);
1338 repo_b.save_memory(&mem_b).await.unwrap();
1339 repo_b.push(&auth, "main").await.unwrap();
1340
1341 let mem_a2 = make_memory("shared", "version from A (older)", 1_700_000_200);
1343 repo_a.save_memory(&mem_a2).await.unwrap();
1344 let result = repo_a.pull(&auth, "main").await.unwrap();
1345
1346 assert!(
1347 matches!(result, PullResult::Merged { conflicts_resolved, .. } if conflicts_resolved >= 1),
1348 "expected merge with conflicts resolved, got {:?}",
1349 result
1350 );
1351
1352 let file = dir_a.path().join("global").join("shared.md");
1354 let content = std::fs::read_to_string(&file).unwrap();
1355 assert!(
1356 content.contains("version from B (newer)"),
1357 "expected B's version to win (newer timestamp), got: {}",
1358 content
1359 );
1360 }
1361
1362 #[tokio::test]
1363 async fn pull_merge_conflict_ours_newer_wins() {
1364 let (_remote_dir, remote_url) = setup_bare_remote();
1365 let auth = test_auth();
1366
1367 let dir_a = tempfile::tempdir().unwrap();
1369 let repo_a = open_repo(&dir_a, Some(&remote_url));
1370 let mem_a1 = make_memory("shared", "version from A initial", 1_700_000_100);
1371 repo_a.save_memory(&mem_a1).await.unwrap();
1372 repo_a.push(&auth, "main").await.unwrap();
1373
1374 let dir_b = tempfile::tempdir().unwrap();
1376 let repo_b = open_repo(&dir_b, Some(&remote_url));
1377 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1378 repo_b.save_memory(&seed).await.unwrap();
1379 repo_b.pull(&auth, "main").await.unwrap();
1380
1381 let mem_b = make_memory("shared", "version from B (older)", 1_700_000_200);
1382 repo_b.save_memory(&mem_b).await.unwrap();
1383 repo_b.push(&auth, "main").await.unwrap();
1384
1385 let mem_a2 = make_memory("shared", "version from A (newer)", 1_700_000_300);
1387 repo_a.save_memory(&mem_a2).await.unwrap();
1388 let result = repo_a.pull(&auth, "main").await.unwrap();
1389
1390 assert!(
1391 matches!(result, PullResult::Merged { conflicts_resolved, .. } if conflicts_resolved >= 1),
1392 "expected merge with conflicts resolved, got {:?}",
1393 result
1394 );
1395
1396 let file = dir_a.path().join("global").join("shared.md");
1398 let content = std::fs::read_to_string(&file).unwrap();
1399 assert!(
1400 content.contains("version from A (newer)"),
1401 "expected A's version to win (newer timestamp), got: {}",
1402 content
1403 );
1404 }
1405
1406 #[tokio::test]
1407 async fn pull_merge_no_conflict_different_files() {
1408 let (_remote_dir, remote_url) = setup_bare_remote();
1409 let auth = test_auth();
1410
1411 let dir_a = tempfile::tempdir().unwrap();
1413 let repo_a = open_repo(&dir_a, Some(&remote_url));
1414 let mem_a = make_memory("mem-a", "from A", 1_700_000_100);
1415 repo_a.save_memory(&mem_a).await.unwrap();
1416 repo_a.push(&auth, "main").await.unwrap();
1417
1418 let dir_b = tempfile::tempdir().unwrap();
1420 let repo_b = open_repo(&dir_b, Some(&remote_url));
1421 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1422 repo_b.save_memory(&seed).await.unwrap();
1423 repo_b.pull(&auth, "main").await.unwrap();
1424 let mem_b = make_memory("mem-b", "from B", 1_700_000_200);
1425 repo_b.save_memory(&mem_b).await.unwrap();
1426 repo_b.push(&auth, "main").await.unwrap();
1427
1428 let mem_a2 = make_memory("mem-a2", "also from A", 1_700_000_300);
1430 repo_a.save_memory(&mem_a2).await.unwrap();
1431 let result = repo_a.pull(&auth, "main").await.unwrap();
1432
1433 assert!(
1434 matches!(
1435 result,
1436 PullResult::Merged {
1437 conflicts_resolved: 0,
1438 ..
1439 }
1440 ),
1441 "expected clean merge, got {:?}",
1442 result
1443 );
1444
1445 assert!(dir_a.path().join("global").join("mem-b.md").exists());
1447 }
1448
1449 fn commit_file(repo: &Arc<MemoryRepo>, rel_path: &str, content: &str) -> [u8; 20] {
1453 let inner = repo.inner.lock().expect("lock poisoned");
1454 let full_path = repo.root.join(rel_path);
1455 if let Some(parent) = full_path.parent() {
1456 std::fs::create_dir_all(parent).unwrap();
1457 }
1458 std::fs::write(&full_path, content).unwrap();
1459
1460 let mut index = inner.index().unwrap();
1461 index.add_path(std::path::Path::new(rel_path)).unwrap();
1462 index.write().unwrap();
1463 let tree_oid = index.write_tree().unwrap();
1464 let tree = inner.find_tree(tree_oid).unwrap();
1465 let sig = git2::Signature::now("test", "test@test.com").unwrap();
1466
1467 let oid = match inner.head() {
1468 Ok(head) => {
1469 let parent = head.peel_to_commit().unwrap();
1470 inner
1471 .commit(Some("HEAD"), &sig, &sig, "test commit", &tree, &[&parent])
1472 .unwrap()
1473 }
1474 Err(_) => inner
1475 .commit(Some("HEAD"), &sig, &sig, "initial commit", &tree, &[])
1476 .unwrap(),
1477 };
1478
1479 let mut buf = [0u8; 20];
1480 buf.copy_from_slice(oid.as_bytes());
1481 buf
1482 }
1483
/// A newly committed `global/*.md` file must be reported as exactly one upsert, keyed by its
/// path without the `.md` extension, with no removals.
#[test]
fn diff_changed_memories_detects_added_global() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    // Snapshot the commit HEAD points at right after `open_repo`.
    let before: [u8; 20] = {
        let guard = repo.inner.lock().unwrap();
        let start = guard.head().unwrap().peel_to_commit().unwrap();
        let mut bytes = [0u8; 20];
        bytes.copy_from_slice(start.id().as_bytes());
        bytes
    };

    let after = commit_file(&repo, "global/my-note.md", "# content");

    let diff = repo.diff_changed_memories(before, after).unwrap();
    assert_eq!(diff.upserted, vec!["global/my-note".to_string()]);
    assert!(diff.removed.is_empty());
}
1504
/// Deleting a committed `global/*.md` file must be reported as exactly one removal and no
/// upserts.
#[test]
fn diff_changed_memories_detects_deleted() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let with_file = commit_file(&repo, "global/to-delete.md", "hello");

    let without_file: [u8; 20] = {
        let guard = repo.inner.lock().unwrap();
        let rel = std::path::Path::new("global/to-delete.md");

        // Drop the file from the working tree and the index, then commit the resulting tree.
        std::fs::remove_file(dir.path().join("global/to-delete.md")).unwrap();
        let mut index = guard.index().unwrap();
        index.remove_path(rel).unwrap();
        index.write().unwrap();
        let tree = guard.find_tree(index.write_tree().unwrap()).unwrap();

        let sig = git2::Signature::now("test", "test@test.com").unwrap();
        let parent = guard.head().unwrap().peel_to_commit().unwrap();
        let commit_id = guard
            .commit(Some("HEAD"), &sig, &sig, "delete file", &tree, &[&parent])
            .unwrap();

        let mut bytes = [0u8; 20];
        bytes.copy_from_slice(commit_id.as_bytes());
        bytes
    };

    let diff = repo.diff_changed_memories(with_file, without_file).unwrap();
    assert!(diff.upserted.is_empty());
    assert_eq!(diff.removed, vec!["global/to-delete".to_string()]);
}
1536
/// Non-markdown files and markdown files outside the memory scopes must not appear in the diff
/// at all.
#[test]
fn diff_changed_memories_ignores_non_md_files() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let before: [u8; 20] = {
        let guard = repo.inner.lock().unwrap();
        let tip = guard.head().unwrap().peel_to_commit().unwrap().id();
        let mut bytes = [0u8; 20];
        bytes.copy_from_slice(tip.as_bytes());
        bytes
    };

    // First a non-.md file inside `global/`, then a .md file in an unrelated directory.
    commit_file(&repo, "global/config.json", "{}");
    let after = commit_file(&repo, "other/note.md", "# ignored");

    let diff = repo.diff_changed_memories(before, after).unwrap();
    assert!(
        diff.upserted.is_empty(),
        "should ignore non-.md and out-of-scope files"
    );
    assert!(diff.removed.is_empty());
}
1568
/// Rewriting an existing project-scoped memory must surface as a single upsert.
#[test]
fn diff_changed_memories_detects_modified() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let note = "projects/myproject/note.md";
    let v1 = commit_file(&repo, note, "version 1");
    let v2 = commit_file(&repo, note, "version 2");

    let diff = repo.diff_changed_memories(v1, v2).unwrap();
    let expected = vec!["projects/myproject/note".to_string()];
    assert_eq!(diff.upserted, expected);
    assert!(diff.removed.is_empty());
}
1584
/// An all-zero "old" OID means "no previous commit": files in the new tree must be reported as
/// upserts and nothing as removed.
#[test]
fn diff_changed_memories_zero_oid_treats_all_as_added() {
    let dir = tempfile::tempdir().unwrap();
    let repo = open_repo(&dir, None);

    let tip = commit_file(&repo, "global/first-memory.md", "# Hello");
    let no_previous = [0u8; 20];

    let diff = repo.diff_changed_memories(no_previous, tip).unwrap();
    assert_eq!(
        diff.upserted,
        vec!["global/first-memory".to_string()],
        "zero OID: all new-tree files should be additions"
    );
    assert!(diff.removed.is_empty(), "zero OID: no removals expected");
}
1606}