1use std::{
2 path::{Path, PathBuf},
3 sync::{Arc, Mutex},
4};
5
6use git2::{build::CheckoutBuilder, ErrorCode, MergeOptions, Repository, Signature};
7use tracing::{info, warn};
8
9use secrecy::{ExposeSecret, SecretString};
10
11use crate::{
12 auth::AuthProvider,
13 error::MemoryError,
14 types::{validate_name, ChangedMemories, Memory, PullResult, Scope},
15};
16
/// Returns `url` with any userinfo (`user:password@`) replaced by `[REDACTED]`.
///
/// Only the authority section (between `scheme://` and the first `/`, `?` or
/// `#`) is inspected, so an `@` in the path or query (e.g. `repo@v1`) is left
/// alone. The *last* `@` in the authority is used as the separator so that an
/// unencoded `@` inside a password is still fully redacted. URLs without a
/// `://` scheme separator (e.g. scp-style `git@host:path`) are returned
/// unchanged.
fn redact_url(url: &str) -> String {
    let scheme_end = match url.find("://") {
        Some(idx) => idx + 3,
        None => return url.to_string(),
    };
    let (scheme, rest) = url.split_at(scheme_end);

    // The authority ends at the first path, query or fragment delimiter.
    let authority_len = rest
        .find(|c: char| c == '/' || c == '?' || c == '#')
        .unwrap_or(rest.len());
    let authority = &rest[..authority_len];

    match authority.rfind('@') {
        Some(at) => format!("{}[REDACTED]@{}", scheme, &rest[at + 1..]),
        None => url.to_string(),
    }
}
34
35fn capture_head_oid(repo: &git2::Repository) -> Result<[u8; 20], MemoryError> {
39 match repo.head() {
40 Ok(h) => {
41 let oid = h.peel_to_commit()?.id();
42 let mut buf = [0u8; 20];
43 buf.copy_from_slice(oid.as_bytes());
44 Ok(buf)
45 }
46 Err(e) if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound => {
48 Ok([0u8; 20])
49 }
50 Err(e) => Err(MemoryError::Git(e)),
51 }
52}
53
54fn fast_forward(
59 repo: &git2::Repository,
60 fetch_commit: &git2::AnnotatedCommit,
61 branch: &str,
62) -> Result<PullResult, MemoryError> {
63 let old_head = capture_head_oid(repo)?;
64
65 let refname = format!("refs/heads/{branch}");
66 let target_oid = fetch_commit.id();
67
68 match repo.find_reference(&refname) {
69 Ok(mut reference) => {
70 reference.set_target(target_oid, &format!("pull: fast-forward to {}", target_oid))?;
71 }
72 Err(e) if e.code() == ErrorCode::NotFound => {
73 repo.reference(
75 &refname,
76 target_oid,
77 true,
78 &format!("pull: create branch {} from fetch", branch),
79 )?;
80 }
81 Err(e) => return Err(MemoryError::Git(e)),
82 }
83
84 repo.set_head(&refname)?;
85 let mut checkout = CheckoutBuilder::default();
86 checkout.force();
87 repo.checkout_head(Some(&mut checkout))?;
88
89 let mut new_head = [0u8; 20];
90 new_head.copy_from_slice(target_oid.as_bytes());
91
92 info!("pull: fast-forwarded to {}", target_oid);
93 Ok(PullResult::FastForward { old_head, new_head })
94}
95
96fn build_auth_callbacks(token: SecretString) -> git2::RemoteCallbacks<'static> {
100 let mut callbacks = git2::RemoteCallbacks::new();
101 callbacks.credentials(move |_url, _username, _allowed| {
102 git2::Cred::userpass_plaintext("x-access-token", token.expose_secret())
103 });
104 callbacks
105}
106
107pub struct MemoryRepo {
109 inner: Mutex<Repository>,
110 root: PathBuf,
111}
112
113unsafe impl Send for MemoryRepo {}
117unsafe impl Sync for MemoryRepo {}
118
119impl MemoryRepo {
120 pub fn init_or_open(path: &Path, remote_url: Option<&str>) -> Result<Self, MemoryError> {
125 let repo = if path.join(".git").exists() {
126 Repository::open(path)?
127 } else {
128 let mut opts = git2::RepositoryInitOptions::new();
129 opts.initial_head("main");
130 let repo = Repository::init_opts(path, &opts)?;
131 let gitignore = path.join(".gitignore");
133 if !gitignore.exists() {
134 std::fs::write(&gitignore, ".memory-mcp-index/\n")?;
135 }
136 {
138 let mut index = repo.index()?;
139 index.add_path(Path::new(".gitignore"))?;
140 index.write()?;
141 let tree_oid = index.write_tree()?;
142 let tree = repo.find_tree(tree_oid)?;
143 let sig = Signature::now("memory-mcp", "memory-mcp@local")?;
144 repo.commit(
145 Some("HEAD"),
146 &sig,
147 &sig,
148 "chore: init repository",
149 &tree,
150 &[],
151 )?;
152 }
153 repo
154 };
155
156 if let Some(url) = remote_url {
158 match repo.find_remote("origin") {
159 Ok(existing) => {
160 let current_url = existing.url().unwrap_or("");
162 if current_url != url {
163 repo.remote_set_url("origin", url)?;
164 info!("updated origin remote URL to {}", redact_url(url));
165 }
166 }
167 Err(e) if e.code() == ErrorCode::NotFound => {
168 repo.remote("origin", url)?;
169 info!("created origin remote pointing at {}", redact_url(url));
170 }
171 Err(e) => return Err(MemoryError::Git(e)),
172 }
173 }
174
175 Ok(Self {
176 inner: Mutex::new(repo),
177 root: path.to_path_buf(),
178 })
179 }
180
181 fn memory_path(&self, name: &str, scope: &Scope) -> PathBuf {
183 self.root
184 .join(scope.dir_prefix())
185 .join(format!("{}.md", name))
186 }
187
188 pub async fn save_memory(self: &Arc<Self>, memory: &Memory) -> Result<(), MemoryError> {
193 validate_name(&memory.name)?;
194 if let Scope::Project(ref project_name) = memory.metadata.scope {
195 validate_name(project_name)?;
196 }
197
198 let file_path = self.memory_path(&memory.name, &memory.metadata.scope);
199 self.assert_within_root(&file_path)?;
200
201 let arc = Arc::clone(self);
202 let memory = memory.clone();
203 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
204 let repo = arc
205 .inner
206 .lock()
207 .expect("lock poisoned — prior panic corrupted state");
208
209 if let Some(parent) = file_path.parent() {
211 std::fs::create_dir_all(parent)?;
212 }
213
214 let markdown = memory.to_markdown()?;
215 arc.write_memory_file(&file_path, markdown.as_bytes())?;
216
217 arc.git_add_and_commit(
218 &repo,
219 &file_path,
220 &format!("chore: save memory '{}'", memory.name),
221 )?;
222 Ok(())
223 })
224 .await
225 .map_err(|e| MemoryError::Join(e.to_string()))?
226 }
227
228 pub async fn delete_memory(
230 self: &Arc<Self>,
231 name: &str,
232 scope: &Scope,
233 ) -> Result<(), MemoryError> {
234 validate_name(name)?;
235 if let Scope::Project(ref project_name) = *scope {
236 validate_name(project_name)?;
237 }
238
239 let file_path = self.memory_path(name, scope);
240 self.assert_within_root(&file_path)?;
241
242 let arc = Arc::clone(self);
243 let name = name.to_string();
244 let file_path_clone = file_path.clone();
245 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
246 let repo = arc
247 .inner
248 .lock()
249 .expect("lock poisoned — prior panic corrupted state");
250
251 match std::fs::symlink_metadata(&file_path_clone) {
253 Err(_) => return Err(MemoryError::NotFound { name: name.clone() }),
254 Ok(m) if m.file_type().is_symlink() => {
255 return Err(MemoryError::InvalidInput {
256 reason: format!(
257 "path '{}' is a symlink, which is not permitted",
258 file_path_clone.display()
259 ),
260 });
261 }
262 Ok(_) => {}
263 }
264
265 std::fs::remove_file(&file_path_clone)?;
266 let relative =
268 file_path_clone
269 .strip_prefix(&arc.root)
270 .map_err(|e| MemoryError::InvalidInput {
271 reason: format!("path strip error: {}", e),
272 })?;
273 let mut index = repo.index()?;
274 index.remove_path(relative)?;
275 index.write()?;
276
277 let tree_oid = index.write_tree()?;
278 let tree = repo.find_tree(tree_oid)?;
279 let sig = arc.signature(&repo)?;
280 let message = format!("chore: delete memory '{}'", name);
281
282 match repo.head() {
283 Ok(head) => {
284 let parent_commit = head.peel_to_commit()?;
285 repo.commit(Some("HEAD"), &sig, &sig, &message, &tree, &[&parent_commit])?;
286 }
287 Err(e)
288 if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound =>
289 {
290 repo.commit(Some("HEAD"), &sig, &sig, &message, &tree, &[])?;
291 }
292 Err(e) => return Err(MemoryError::Git(e)),
293 }
294
295 Ok(())
296 })
297 .await
298 .map_err(|e| MemoryError::Join(e.to_string()))?
299 }
300
301 pub async fn read_memory(
303 self: &Arc<Self>,
304 name: &str,
305 scope: &Scope,
306 ) -> Result<Memory, MemoryError> {
307 validate_name(name)?;
308 if let Scope::Project(ref project_name) = *scope {
309 validate_name(project_name)?;
310 }
311
312 let file_path = self.memory_path(name, scope);
313 self.assert_within_root(&file_path)?;
314
315 let arc = Arc::clone(self);
316 let name = name.to_string();
317 tokio::task::spawn_blocking(move || -> Result<Memory, MemoryError> {
318 match std::fs::symlink_metadata(&file_path) {
320 Err(_) => return Err(MemoryError::NotFound { name }),
321 Ok(m) if m.file_type().is_symlink() => {
322 return Err(MemoryError::InvalidInput {
323 reason: format!(
324 "path '{}' is a symlink, which is not permitted",
325 file_path.display()
326 ),
327 });
328 }
329 Ok(_) => {}
330 }
331 let raw = arc.read_memory_file(&file_path)?;
332 Memory::from_markdown(&raw)
333 })
334 .await
335 .map_err(|e| MemoryError::Join(e.to_string()))?
336 }
337
338 pub async fn list_memories(
340 self: &Arc<Self>,
341 scope: Option<&Scope>,
342 ) -> Result<Vec<Memory>, MemoryError> {
343 let root = self.root.clone();
344 let scope_clone = scope.cloned();
345
346 tokio::task::spawn_blocking(move || -> Result<Vec<Memory>, MemoryError> {
347 let dirs: Vec<PathBuf> = match scope_clone.as_ref() {
348 Some(s) => vec![root.join(s.dir_prefix())],
349 None => {
350 let mut dirs = Vec::new();
352 let global = root.join("global");
353 if global.exists() {
354 dirs.push(global);
355 }
356 let projects = root.join("projects");
357 if projects.exists() {
358 for entry in std::fs::read_dir(&projects)? {
359 let entry = entry?;
360 if entry.file_type()?.is_dir() {
361 dirs.push(entry.path());
362 }
363 }
364 }
365 dirs
366 }
367 };
368
369 fn collect_md_files(dir: &Path, out: &mut Vec<Memory>) -> Result<(), MemoryError> {
370 if !dir.exists() {
371 return Ok(());
372 }
373 for entry in std::fs::read_dir(dir)? {
374 let entry = entry?;
375 let path = entry.path();
376 let ft = entry.file_type()?;
377 if ft.is_symlink() {
379 warn!(
380 "skipping symlink at {:?} — symlinks are not permitted in the memory store",
381 path
382 );
383 continue;
384 }
385 if ft.is_dir() {
386 collect_md_files(&path, out)?;
387 } else if path.extension().and_then(|e| e.to_str()) == Some("md") {
388 let raw = std::fs::read_to_string(&path)?;
389 match Memory::from_markdown(&raw) {
390 Ok(m) => out.push(m),
391 Err(e) => {
392 warn!("skipping {:?}: {}", path, e);
393 }
394 }
395 }
396 }
397 Ok(())
398 }
399
400 let mut memories = Vec::new();
401 for dir in dirs {
402 collect_md_files(&dir, &mut memories)?;
403 }
404
405 Ok(memories)
406 })
407 .await
408 .map_err(|e| MemoryError::Join(e.to_string()))?
409 }
410
411 pub async fn push(
416 self: &Arc<Self>,
417 auth: &AuthProvider,
418 branch: &str,
419 ) -> Result<(), MemoryError> {
420 let token_result = auth.resolve_token();
424 let arc = Arc::clone(self);
425 let branch = branch.to_string();
426
427 tokio::task::spawn_blocking(move || -> Result<(), MemoryError> {
428 let repo = arc
429 .inner
430 .lock()
431 .expect("lock poisoned — prior panic corrupted state");
432
433 let mut remote = match repo.find_remote("origin") {
434 Ok(r) => r,
435 Err(e) if e.code() == ErrorCode::NotFound => {
436 warn!("push: no origin remote configured — skipping (local-only mode)");
437 return Ok(());
438 }
439 Err(e) => return Err(MemoryError::Git(e)),
440 };
441
442 let token = token_result?;
444 let mut callbacks = build_auth_callbacks(token);
445
446 let rejections: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
449 let rej = Arc::clone(&rejections);
450 callbacks.push_update_reference(move |refname, status| {
451 if let Some(msg) = status {
452 rej.lock()
453 .expect("rejection lock poisoned")
454 .push(format!("{refname}: {msg}"));
455 }
456 Ok(())
457 });
458
459 let mut push_opts = git2::PushOptions::new();
460 push_opts.remote_callbacks(callbacks);
461
462 let refspec = format!("refs/heads/{branch}:refs/heads/{branch}");
463 if let Err(e) = remote.push(&[&refspec], Some(&mut push_opts)) {
464 warn!("push to origin failed at transport level: {e}");
465 return Err(MemoryError::Git(e));
466 }
467
468 let rejected = rejections.lock().expect("rejection lock poisoned");
469 if !rejected.is_empty() {
470 return Err(MemoryError::PushRejected(rejected.join("; ")));
471 }
472
473 info!("pushed branch '{}' to origin", branch);
474 Ok(())
475 })
476 .await
477 .map_err(|e| MemoryError::Join(e.to_string()))?
478 }
479
480 fn merge_with_remote(
485 &self,
486 repo: &git2::Repository,
487 fetch_commit: &git2::AnnotatedCommit,
488 branch: &str,
489 ) -> Result<PullResult, MemoryError> {
490 let oid = repo.head()?.peel_to_commit()?.id();
494 let mut old_head = [0u8; 20];
495 old_head.copy_from_slice(oid.as_bytes());
496
497 let mut merge_opts = MergeOptions::new();
498 merge_opts.fail_on_conflict(false);
499 repo.merge(&[fetch_commit], Some(&mut merge_opts), None)?;
500
501 let mut index = repo.index()?;
502 let conflicts_resolved = if index.has_conflicts() {
503 self.resolve_conflicts_by_recency(repo, &mut index)?
504 } else {
505 0
506 };
507
508 if index.has_conflicts() {
512 let _ = repo.cleanup_state();
513 return Err(MemoryError::Internal(
514 "unresolved conflicts remain after auto-resolution".into(),
515 ));
516 }
517
518 index.write()?;
520 let tree_oid = index.write_tree()?;
521 let tree = repo.find_tree(tree_oid)?;
522 let sig = self.signature(repo)?;
523
524 let head_commit = repo.head()?.peel_to_commit()?;
525 let fetch_commit_obj = repo.find_commit(fetch_commit.id())?;
526
527 let new_commit_oid = repo.commit(
528 Some("HEAD"),
529 &sig,
530 &sig,
531 &format!("chore: merge origin/{}", branch),
532 &tree,
533 &[&head_commit, &fetch_commit_obj],
534 )?;
535
536 repo.cleanup_state()?;
537
538 let mut new_head = [0u8; 20];
539 new_head.copy_from_slice(new_commit_oid.as_bytes());
540
541 info!(
542 "pull: merge complete ({} conflicts auto-resolved)",
543 conflicts_resolved
544 );
545 Ok(PullResult::Merged {
546 conflicts_resolved,
547 old_head,
548 new_head,
549 })
550 }
551
552 pub async fn pull(
558 self: &Arc<Self>,
559 auth: &AuthProvider,
560 branch: &str,
561 ) -> Result<PullResult, MemoryError> {
562 let token_result = auth.resolve_token();
566 let arc = Arc::clone(self);
567 let branch = branch.to_string();
568
569 tokio::task::spawn_blocking(move || -> Result<PullResult, MemoryError> {
570 let repo = arc
571 .inner
572 .lock()
573 .expect("lock poisoned — prior panic corrupted state");
574
575 let mut remote = match repo.find_remote("origin") {
577 Ok(r) => r,
578 Err(e) if e.code() == ErrorCode::NotFound => {
579 warn!("pull: no origin remote configured — skipping (local-only mode)");
580 return Ok(PullResult::NoRemote);
581 }
582 Err(e) => return Err(MemoryError::Git(e)),
583 };
584
585 let token = token_result?;
587
588 let callbacks = build_auth_callbacks(token);
590 let mut fetch_opts = git2::FetchOptions::new();
591 fetch_opts.remote_callbacks(callbacks);
592 remote.fetch(&[&branch], Some(&mut fetch_opts), None)?;
593
594 let fetch_head = match repo.find_reference("FETCH_HEAD") {
596 Ok(r) => r,
597 Err(e) if e.code() == ErrorCode::NotFound => {
598 return Ok(PullResult::UpToDate);
600 }
601 Err(e)
602 if e.class() == git2::ErrorClass::Reference
603 && e.message().contains("corrupted") =>
604 {
605 info!("pull: FETCH_HEAD is empty or corrupted — treating as empty remote");
607 return Ok(PullResult::UpToDate);
608 }
609 Err(e) => return Err(MemoryError::Git(e)),
610 };
611 let fetch_commit = match repo.reference_to_annotated_commit(&fetch_head) {
612 Ok(c) => c,
613 Err(e) if e.class() == git2::ErrorClass::Reference => {
614 info!("pull: FETCH_HEAD not resolvable — treating as empty remote");
616 return Ok(PullResult::UpToDate);
617 }
618 Err(e) => return Err(MemoryError::Git(e)),
619 };
620
621 let (analysis, _preference) = repo.merge_analysis(&[&fetch_commit])?;
623
624 if analysis.is_up_to_date() {
625 info!("pull: already up to date");
626 return Ok(PullResult::UpToDate);
627 }
628
629 if analysis.is_fast_forward() {
630 return fast_forward(&repo, &fetch_commit, &branch);
631 }
632
633 arc.merge_with_remote(&repo, &fetch_commit, &branch)
634 })
635 .await
636 .map_err(|e| MemoryError::Join(e.to_string()))?
637 }
638
639 pub fn diff_changed_memories(
647 &self,
648 old_oid: [u8; 20],
649 new_oid: [u8; 20],
650 ) -> Result<ChangedMemories, MemoryError> {
651 let repo = self
652 .inner
653 .lock()
654 .expect("lock poisoned — prior panic corrupted state");
655
656 let new_git_oid = git2::Oid::from_bytes(&new_oid).map_err(MemoryError::Git)?;
657 let new_tree = repo.find_commit(new_git_oid)?.tree()?;
658
659 let diff = if old_oid == [0u8; 20] {
662 repo.diff_tree_to_tree(None, Some(&new_tree), None)?
663 } else {
664 let old_git_oid = git2::Oid::from_bytes(&old_oid).map_err(MemoryError::Git)?;
665 let old_tree = repo.find_commit(old_git_oid)?.tree()?;
666 repo.diff_tree_to_tree(Some(&old_tree), Some(&new_tree), None)?
667 };
668
669 let mut changes = ChangedMemories::default();
670
671 diff.foreach(
672 &mut |delta, _progress| {
673 use git2::Delta;
674
675 let path = match delta.new_file().path().or_else(|| delta.old_file().path()) {
676 Some(p) => p,
677 None => return true,
678 };
679
680 let path_str = match path.to_str() {
681 Some(s) => s,
682 None => return true,
683 };
684
685 if !path_str.ends_with(".md") {
687 return true;
688 }
689 if !path_str.starts_with("global/") && !path_str.starts_with("projects/") {
690 return true;
691 }
692
693 let qualified = &path_str[..path_str.len() - 3];
695
696 match delta.status() {
697 Delta::Added | Delta::Modified => {
698 changes.upserted.push(qualified.to_string());
699 }
700 Delta::Renamed | Delta::Copied => {
701 if matches!(delta.status(), Delta::Renamed) {
704 if let Some(old_path) = delta.old_file().path().and_then(|p| p.to_str())
705 {
706 if old_path.ends_with(".md")
707 && (old_path.starts_with("global/")
708 || old_path.starts_with("projects/"))
709 {
710 changes
711 .removed
712 .push(old_path[..old_path.len() - 3].to_string());
713 }
714 }
715 }
716 changes.upserted.push(qualified.to_string());
717 }
718 Delta::Deleted => {
719 changes.removed.push(qualified.to_string());
720 }
721 _ => {}
722 }
723
724 true
725 },
726 None,
727 None,
728 None,
729 )
730 .map_err(MemoryError::Git)?;
731
732 Ok(changes)
733 }
734
735 fn resolve_conflicts_by_recency(
745 &self,
746 repo: &Repository,
747 index: &mut git2::Index,
748 ) -> Result<usize, MemoryError> {
749 struct ConflictInfo {
751 path: PathBuf,
752 our_blob: Option<Vec<u8>>,
753 their_blob: Option<Vec<u8>>,
754 }
755
756 let mut conflicts_info: Vec<ConflictInfo> = Vec::new();
757
758 {
759 let conflicts = index.conflicts()?;
760 for conflict in conflicts {
761 let conflict = conflict?;
762
763 let path = conflict
764 .our
765 .as_ref()
766 .or(conflict.their.as_ref())
767 .and_then(|e| std::str::from_utf8(&e.path).ok())
768 .map(|s| self.root.join(s));
769
770 let path = match path {
771 Some(p) => p,
772 None => continue,
773 };
774
775 let our_blob = conflict
776 .our
777 .as_ref()
778 .and_then(|e| repo.find_blob(e.id).ok())
779 .map(|b| b.content().to_vec());
780
781 let their_blob = conflict
782 .their
783 .as_ref()
784 .and_then(|e| repo.find_blob(e.id).ok())
785 .map(|b| b.content().to_vec());
786
787 conflicts_info.push(ConflictInfo {
788 path,
789 our_blob,
790 their_blob,
791 });
792 }
793 }
794
795 let mut resolved = 0usize;
796
797 for info in conflicts_info {
798 let our_str = info
799 .our_blob
800 .as_deref()
801 .and_then(|b| std::str::from_utf8(b).ok())
802 .map(str::to_owned);
803 let their_str = info
804 .their_blob
805 .as_deref()
806 .and_then(|b| std::str::from_utf8(b).ok())
807 .map(str::to_owned);
808
809 let our_ts = our_str
810 .as_deref()
811 .and_then(|s| Memory::from_markdown(s).ok())
812 .map(|m| m.metadata.updated_at);
813 let their_ts = their_str
814 .as_deref()
815 .and_then(|s| Memory::from_markdown(s).ok())
816 .map(|m| m.metadata.updated_at);
817
818 let (chosen_bytes, label): (Vec<u8>, String) =
820 match (our_str.as_deref(), their_str.as_deref()) {
821 (Some(ours), Some(theirs)) => match (our_ts, their_ts) {
822 (Some(ot), Some(tt)) if tt > ot => (
823 theirs.as_bytes().to_vec(),
824 format!("theirs (updated_at: {})", tt),
825 ),
826 (Some(ot), _) => (
827 ours.as_bytes().to_vec(),
828 format!("ours (updated_at: {})", ot),
829 ),
830 _ => (
831 ours.as_bytes().to_vec(),
832 "ours (timestamp unparseable)".to_string(),
833 ),
834 },
835 (Some(ours), None) => (
836 ours.as_bytes().to_vec(),
837 "ours (theirs missing)".to_string(),
838 ),
839 (None, Some(theirs)) => (
840 theirs.as_bytes().to_vec(),
841 "theirs (ours missing)".to_string(),
842 ),
843 (None, None) => {
844 match (info.our_blob.as_deref(), info.their_blob.as_deref()) {
846 (Some(ours), _) => {
847 (ours.to_vec(), "ours (binary/non-UTF-8)".to_string())
848 }
849 (_, Some(theirs)) => {
850 (theirs.to_vec(), "theirs (binary/non-UTF-8)".to_string())
851 }
852 (None, None) => {
853 warn!(
856 "conflict at '{}': both sides missing — removing from index",
857 info.path.display()
858 );
859 let relative = info.path.strip_prefix(&self.root).map_err(|e| {
860 MemoryError::InvalidInput {
861 reason: format!(
862 "path strip error during conflict resolution: {}",
863 e
864 ),
865 }
866 })?;
867 index.conflict_remove(relative)?;
868 resolved += 1;
869 continue;
870 }
871 }
872 }
873 };
874
875 warn!(
876 "conflict resolved: {} — kept {}",
877 info.path.display(),
878 label
879 );
880
881 self.assert_within_root(&info.path)?;
885 if let Some(parent) = info.path.parent() {
886 std::fs::create_dir_all(parent)?;
887 }
888 self.write_memory_file(&info.path, &chosen_bytes)?;
889
890 let relative =
892 info.path
893 .strip_prefix(&self.root)
894 .map_err(|e| MemoryError::InvalidInput {
895 reason: format!("path strip error during conflict resolution: {}", e),
896 })?;
897 index.add_path(relative)?;
898
899 resolved += 1;
900 }
901
902 Ok(resolved)
903 }
904
905 fn signature<'r>(&self, repo: &'r Repository) -> Result<Signature<'r>, MemoryError> {
906 let sig = repo
908 .signature()
909 .or_else(|_| Signature::now("memory-mcp", "memory-mcp@local"))?;
910 Ok(sig)
911 }
912
913 fn git_add_and_commit(
915 &self,
916 repo: &Repository,
917 file_path: &Path,
918 message: &str,
919 ) -> Result<(), MemoryError> {
920 let relative =
921 file_path
922 .strip_prefix(&self.root)
923 .map_err(|e| MemoryError::InvalidInput {
924 reason: format!("path strip error: {}", e),
925 })?;
926
927 let mut index = repo.index()?;
928 index.add_path(relative)?;
929 index.write()?;
930
931 let tree_oid = index.write_tree()?;
932 let tree = repo.find_tree(tree_oid)?;
933 let sig = self.signature(repo)?;
934
935 match repo.head() {
936 Ok(head) => {
937 let parent_commit = head.peel_to_commit()?;
938 repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &[&parent_commit])?;
939 }
940 Err(e) if e.code() == ErrorCode::UnbornBranch || e.code() == ErrorCode::NotFound => {
941 repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &[])?;
943 }
944 Err(e) => return Err(MemoryError::Git(e)),
945 }
946
947 Ok(())
948 }
949
950 fn assert_within_root(&self, path: &Path) -> Result<(), MemoryError> {
953 let parent = path.parent().unwrap_or(path);
956 let filename = path.file_name().ok_or_else(|| MemoryError::InvalidInput {
957 reason: "path has no filename component".to_string(),
958 })?;
959
960 let canon_parent = {
963 let mut p = parent.to_path_buf();
964 let mut suffixes: Vec<std::ffi::OsString> = Vec::new();
965 loop {
966 match p.canonicalize() {
967 Ok(c) => {
968 let mut full = c;
969 for s in suffixes.into_iter().rev() {
970 full.push(s);
971 }
972 break full;
973 }
974 Err(_) => {
975 if let Some(name) = p.file_name() {
976 suffixes.push(name.to_os_string());
977 }
978 match p.parent() {
979 Some(par) => p = par.to_path_buf(),
980 None => {
981 return Err(MemoryError::InvalidInput {
982 reason: "cannot resolve any ancestor of path".into(),
983 });
984 }
985 }
986 }
987 }
988 }
989 };
990
991 let resolved = canon_parent.join(filename);
992
993 let canon_root = self
994 .root
995 .canonicalize()
996 .map_err(|e| MemoryError::InvalidInput {
997 reason: format!("cannot canonicalize repo root: {}", e),
998 })?;
999
1000 if !resolved.starts_with(&canon_root) {
1001 return Err(MemoryError::InvalidInput {
1002 reason: format!(
1003 "path '{}' escapes repository root '{}'",
1004 resolved.display(),
1005 canon_root.display()
1006 ),
1007 });
1008 }
1009
1010 {
1015 let mut probe = canon_root.clone();
1016 let relative =
1018 resolved
1019 .strip_prefix(&canon_root)
1020 .map_err(|e| MemoryError::InvalidInput {
1021 reason: format!("path strip error: {}", e),
1022 })?;
1023 for component in relative.components() {
1024 probe.push(component);
1025 if (probe.exists() || probe.symlink_metadata().is_ok())
1027 && probe
1028 .symlink_metadata()
1029 .map(|m| m.file_type().is_symlink())
1030 .unwrap_or(false)
1031 {
1032 return Err(MemoryError::InvalidInput {
1033 reason: format!(
1034 "path component '{}' is a symlink, which is not allowed",
1035 probe.display()
1036 ),
1037 });
1038 }
1039 }
1040 }
1041
1042 Ok(())
1043 }
1044
1045 fn write_memory_file(&self, path: &Path, data: &[u8]) -> Result<(), MemoryError> {
1050 #[cfg(unix)]
1051 {
1052 use std::io::Write as _;
1053 use std::os::unix::fs::OpenOptionsExt as _;
1054 let mut f = std::fs::OpenOptions::new()
1055 .write(true)
1056 .create(true)
1057 .truncate(true)
1058 .custom_flags(libc::O_NOFOLLOW)
1059 .open(path)?;
1060 f.write_all(data)?;
1061 Ok(())
1062 }
1063 #[cfg(not(unix))]
1064 {
1065 std::fs::write(path, data)?;
1066 Ok(())
1067 }
1068 }
1069
1070 fn read_memory_file(&self, path: &Path) -> Result<String, MemoryError> {
1075 #[cfg(unix)]
1076 {
1077 use std::io::Read as _;
1078 use std::os::unix::fs::OpenOptionsExt as _;
1079 let mut f = std::fs::OpenOptions::new()
1080 .read(true)
1081 .custom_flags(libc::O_NOFOLLOW)
1082 .open(path)?;
1083 let mut buf = String::new();
1084 f.read_to_string(&mut buf)?;
1085 Ok(buf)
1086 }
1087 #[cfg(not(unix))]
1088 {
1089 Ok(std::fs::read_to_string(path)?)
1090 }
1091 }
1092}
1093
1094#[cfg(test)]
1099mod tests {
1100 use super::*;
1101 use crate::auth::AuthProvider;
1102 use crate::types::{Memory, MemoryMetadata, PullResult, Scope};
1103 use std::sync::Arc;
1104
1105 fn test_auth() -> AuthProvider {
1106 AuthProvider::with_token("test-token-unused-for-file-remotes")
1107 }
1108
1109 fn make_memory(name: &str, content: &str, updated_at_secs: i64) -> Memory {
1110 let meta = MemoryMetadata {
1111 tags: vec![],
1112 scope: Scope::Global,
1113 created_at: chrono::DateTime::from_timestamp(1_700_000_000, 0).unwrap(),
1114 updated_at: chrono::DateTime::from_timestamp(updated_at_secs, 0).unwrap(),
1115 source: None,
1116 };
1117 Memory::new(name.to_string(), content.to_string(), meta)
1118 }
1119
1120 fn setup_bare_remote() -> (tempfile::TempDir, String) {
1121 let dir = tempfile::tempdir().expect("failed to create temp dir");
1122 git2::Repository::init_bare(dir.path()).expect("failed to init bare repo");
1123 let url = format!("file://{}", dir.path().display());
1124 (dir, url)
1125 }
1126
1127 fn open_repo(dir: &tempfile::TempDir, remote_url: Option<&str>) -> Arc<MemoryRepo> {
1128 Arc::new(MemoryRepo::init_or_open(dir.path(), remote_url).expect("failed to init repo"))
1129 }
1130
1131 #[test]
1134 fn redact_url_strips_userinfo() {
1135 assert_eq!(
1136 redact_url("https://user:ghp_token123@github.com/org/repo.git"),
1137 "https://[REDACTED]@github.com/org/repo.git"
1138 );
1139 }
1140
1141 #[test]
1142 fn redact_url_no_at_passthrough() {
1143 let url = "https://github.com/org/repo.git";
1144 assert_eq!(redact_url(url), url);
1145 }
1146
1147 #[test]
1148 fn redact_url_file_protocol_passthrough() {
1149 let url = "file:///tmp/bare.git";
1150 assert_eq!(redact_url(url), url);
1151 }
1152
1153 #[test]
1156 fn assert_within_root_accepts_valid_path() {
1157 let dir = tempfile::tempdir().unwrap();
1158 let repo = MemoryRepo::init_or_open(dir.path(), None).unwrap();
1159 let valid = dir.path().join("global").join("my-memory.md");
1160 std::fs::create_dir_all(valid.parent().unwrap()).unwrap();
1162 assert!(repo.assert_within_root(&valid).is_ok());
1163 }
1164
1165 #[test]
1166 fn assert_within_root_rejects_escape() {
1167 let dir = tempfile::tempdir().unwrap();
1168 let repo = MemoryRepo::init_or_open(dir.path(), None).unwrap();
1169 let _evil = dir
1172 .path()
1173 .join("..")
1174 .join("..")
1175 .join("..")
1176 .join("tmp")
1177 .join("evil.md");
1178 let outside = std::path::PathBuf::from("/tmp/definitely-outside");
1182 assert!(repo.assert_within_root(&outside).is_err());
1183 }
1184
1185 #[tokio::test]
1188 async fn push_local_only_returns_ok() {
1189 let dir = tempfile::tempdir().unwrap();
1190 let repo = open_repo(&dir, None);
1191 let auth = test_auth();
1192 let result = repo.push(&auth, "main").await;
1194 assert!(result.is_ok());
1195 }
1196
1197 #[tokio::test]
1198 async fn pull_local_only_returns_no_remote() {
1199 let dir = tempfile::tempdir().unwrap();
1200 let repo = open_repo(&dir, None);
1201 let auth = test_auth();
1202 let result = repo.pull(&auth, "main").await.unwrap();
1203 assert!(matches!(result, PullResult::NoRemote));
1204 }
1205
1206 #[tokio::test]
1209 async fn push_to_bare_remote() {
1210 let (_remote_dir, remote_url) = setup_bare_remote();
1211 let local_dir = tempfile::tempdir().unwrap();
1212 let repo = open_repo(&local_dir, Some(&remote_url));
1213 let auth = test_auth();
1214
1215 let mem = make_memory("test-push", "push content", 1_700_000_000);
1217 repo.save_memory(&mem).await.unwrap();
1218
1219 repo.push(&auth, "main").await.unwrap();
1221
1222 let bare = git2::Repository::open_bare(_remote_dir.path()).unwrap();
1224 let head = bare.find_reference("refs/heads/main").unwrap();
1225 let commit = head.peel_to_commit().unwrap();
1226 assert!(commit.message().unwrap().contains("test-push"));
1227 }
1228
1229 #[tokio::test]
1230 async fn pull_from_empty_bare_remote_returns_up_to_date() {
1231 let (_remote_dir, remote_url) = setup_bare_remote();
1232 let local_dir = tempfile::tempdir().unwrap();
1233 let repo = open_repo(&local_dir, Some(&remote_url));
1234 let auth = test_auth();
1235
1236 let mem = make_memory("seed", "seed content", 1_700_000_000);
1238 repo.save_memory(&mem).await.unwrap();
1239
1240 let result = repo.pull(&auth, "main").await.unwrap();
1242 assert!(matches!(result, PullResult::UpToDate));
1243 }
1244
1245 #[tokio::test]
1246 async fn pull_fast_forward() {
1247 let (_remote_dir, remote_url) = setup_bare_remote();
1248 let auth = test_auth();
1249
1250 let dir_a = tempfile::tempdir().unwrap();
1252 let repo_a = open_repo(&dir_a, Some(&remote_url));
1253 let mem = make_memory("from-a", "content from A", 1_700_000_000);
1254 repo_a.save_memory(&mem).await.unwrap();
1255 repo_a.push(&auth, "main").await.unwrap();
1256
1257 let dir_b = tempfile::tempdir().unwrap();
1259 let repo_b = open_repo(&dir_b, Some(&remote_url));
1260 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1262 repo_b.save_memory(&seed).await.unwrap();
1263
1264 let result = repo_b.pull(&auth, "main").await.unwrap();
1265 assert!(
1266 matches!(
1267 result,
1268 PullResult::FastForward { .. } | PullResult::Merged { .. }
1269 ),
1270 "expected fast-forward or merge, got {:?}",
1271 result
1272 );
1273
1274 let file = dir_b.path().join("global").join("from-a.md");
1276 assert!(file.exists(), "from-a.md should exist in repo B after pull");
1277 }
1278
1279 #[tokio::test]
1280 async fn pull_up_to_date_after_push() {
1281 let (_remote_dir, remote_url) = setup_bare_remote();
1282 let local_dir = tempfile::tempdir().unwrap();
1283 let repo = open_repo(&local_dir, Some(&remote_url));
1284 let auth = test_auth();
1285
1286 let mem = make_memory("synced", "synced content", 1_700_000_000);
1287 repo.save_memory(&mem).await.unwrap();
1288 repo.push(&auth, "main").await.unwrap();
1289
1290 let result = repo.pull(&auth, "main").await.unwrap();
1292 assert!(matches!(result, PullResult::UpToDate));
1293 }
1294
1295 #[tokio::test]
1298 async fn pull_merge_conflict_theirs_newer_wins() {
1299 let (_remote_dir, remote_url) = setup_bare_remote();
1300 let auth = test_auth();
1301
1302 let dir_a = tempfile::tempdir().unwrap();
1304 let repo_a = open_repo(&dir_a, Some(&remote_url));
1305 let mem_a1 = make_memory("shared", "version from A initial", 1_700_000_100);
1306 repo_a.save_memory(&mem_a1).await.unwrap();
1307 repo_a.push(&auth, "main").await.unwrap();
1308
1309 let dir_b = tempfile::tempdir().unwrap();
1311 let repo_b = open_repo(&dir_b, Some(&remote_url));
1312 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1313 repo_b.save_memory(&seed).await.unwrap();
1314 repo_b.pull(&auth, "main").await.unwrap();
1315
1316 let mem_b = make_memory("shared", "version from B (newer)", 1_700_000_300);
1317 repo_b.save_memory(&mem_b).await.unwrap();
1318 repo_b.push(&auth, "main").await.unwrap();
1319
1320 let mem_a2 = make_memory("shared", "version from A (older)", 1_700_000_200);
1322 repo_a.save_memory(&mem_a2).await.unwrap();
1323 let result = repo_a.pull(&auth, "main").await.unwrap();
1324
1325 assert!(
1326 matches!(result, PullResult::Merged { conflicts_resolved, .. } if conflicts_resolved >= 1),
1327 "expected merge with conflicts resolved, got {:?}",
1328 result
1329 );
1330
1331 let file = dir_a.path().join("global").join("shared.md");
1333 let content = std::fs::read_to_string(&file).unwrap();
1334 assert!(
1335 content.contains("version from B (newer)"),
1336 "expected B's version to win (newer timestamp), got: {}",
1337 content
1338 );
1339 }
1340
1341 #[tokio::test]
1342 async fn pull_merge_conflict_ours_newer_wins() {
1343 let (_remote_dir, remote_url) = setup_bare_remote();
1344 let auth = test_auth();
1345
1346 let dir_a = tempfile::tempdir().unwrap();
1348 let repo_a = open_repo(&dir_a, Some(&remote_url));
1349 let mem_a1 = make_memory("shared", "version from A initial", 1_700_000_100);
1350 repo_a.save_memory(&mem_a1).await.unwrap();
1351 repo_a.push(&auth, "main").await.unwrap();
1352
1353 let dir_b = tempfile::tempdir().unwrap();
1355 let repo_b = open_repo(&dir_b, Some(&remote_url));
1356 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1357 repo_b.save_memory(&seed).await.unwrap();
1358 repo_b.pull(&auth, "main").await.unwrap();
1359
1360 let mem_b = make_memory("shared", "version from B (older)", 1_700_000_200);
1361 repo_b.save_memory(&mem_b).await.unwrap();
1362 repo_b.push(&auth, "main").await.unwrap();
1363
1364 let mem_a2 = make_memory("shared", "version from A (newer)", 1_700_000_300);
1366 repo_a.save_memory(&mem_a2).await.unwrap();
1367 let result = repo_a.pull(&auth, "main").await.unwrap();
1368
1369 assert!(
1370 matches!(result, PullResult::Merged { conflicts_resolved, .. } if conflicts_resolved >= 1),
1371 "expected merge with conflicts resolved, got {:?}",
1372 result
1373 );
1374
1375 let file = dir_a.path().join("global").join("shared.md");
1377 let content = std::fs::read_to_string(&file).unwrap();
1378 assert!(
1379 content.contains("version from A (newer)"),
1380 "expected A's version to win (newer timestamp), got: {}",
1381 content
1382 );
1383 }
1384
1385 #[tokio::test]
1386 async fn pull_merge_no_conflict_different_files() {
1387 let (_remote_dir, remote_url) = setup_bare_remote();
1388 let auth = test_auth();
1389
1390 let dir_a = tempfile::tempdir().unwrap();
1392 let repo_a = open_repo(&dir_a, Some(&remote_url));
1393 let mem_a = make_memory("mem-a", "from A", 1_700_000_100);
1394 repo_a.save_memory(&mem_a).await.unwrap();
1395 repo_a.push(&auth, "main").await.unwrap();
1396
1397 let dir_b = tempfile::tempdir().unwrap();
1399 let repo_b = open_repo(&dir_b, Some(&remote_url));
1400 let seed = make_memory("seed-b", "seed", 1_700_000_000);
1401 repo_b.save_memory(&seed).await.unwrap();
1402 repo_b.pull(&auth, "main").await.unwrap();
1403 let mem_b = make_memory("mem-b", "from B", 1_700_000_200);
1404 repo_b.save_memory(&mem_b).await.unwrap();
1405 repo_b.push(&auth, "main").await.unwrap();
1406
1407 let mem_a2 = make_memory("mem-a2", "also from A", 1_700_000_300);
1409 repo_a.save_memory(&mem_a2).await.unwrap();
1410 let result = repo_a.pull(&auth, "main").await.unwrap();
1411
1412 assert!(
1413 matches!(
1414 result,
1415 PullResult::Merged {
1416 conflicts_resolved: 0,
1417 ..
1418 }
1419 ),
1420 "expected clean merge, got {:?}",
1421 result
1422 );
1423
1424 assert!(dir_a.path().join("global").join("mem-b.md").exists());
1426 }
1427
1428 fn commit_file(repo: &Arc<MemoryRepo>, rel_path: &str, content: &str) -> [u8; 20] {
1432 let inner = repo.inner.lock().expect("lock poisoned");
1433 let full_path = repo.root.join(rel_path);
1434 if let Some(parent) = full_path.parent() {
1435 std::fs::create_dir_all(parent).unwrap();
1436 }
1437 std::fs::write(&full_path, content).unwrap();
1438
1439 let mut index = inner.index().unwrap();
1440 index.add_path(std::path::Path::new(rel_path)).unwrap();
1441 index.write().unwrap();
1442 let tree_oid = index.write_tree().unwrap();
1443 let tree = inner.find_tree(tree_oid).unwrap();
1444 let sig = git2::Signature::now("test", "test@test.com").unwrap();
1445
1446 let oid = match inner.head() {
1447 Ok(head) => {
1448 let parent = head.peel_to_commit().unwrap();
1449 inner
1450 .commit(Some("HEAD"), &sig, &sig, "test commit", &tree, &[&parent])
1451 .unwrap()
1452 }
1453 Err(_) => inner
1454 .commit(Some("HEAD"), &sig, &sig, "initial commit", &tree, &[])
1455 .unwrap(),
1456 };
1457
1458 let mut buf = [0u8; 20];
1459 buf.copy_from_slice(oid.as_bytes());
1460 buf
1461 }
1462
1463 #[test]
1464 fn diff_changed_memories_detects_added_global() {
1465 let dir = tempfile::tempdir().unwrap();
1466 let repo = open_repo(&dir, None);
1467
1468 let old_oid = {
1470 let inner = repo.inner.lock().unwrap();
1471 let head = inner.head().unwrap();
1472 let mut buf = [0u8; 20];
1473 buf.copy_from_slice(head.peel_to_commit().unwrap().id().as_bytes());
1474 buf
1475 };
1476
1477 let new_oid = commit_file(&repo, "global/my-note.md", "# content");
1478
1479 let changes = repo.diff_changed_memories(old_oid, new_oid).unwrap();
1480 assert_eq!(changes.upserted, vec!["global/my-note".to_string()]);
1481 assert!(changes.removed.is_empty());
1482 }
1483
1484 #[test]
1485 fn diff_changed_memories_detects_deleted() {
1486 let dir = tempfile::tempdir().unwrap();
1487 let repo = open_repo(&dir, None);
1488
1489 let first_oid = commit_file(&repo, "global/to-delete.md", "hello");
1490 let second_oid = {
1491 let inner = repo.inner.lock().unwrap();
1492 let full_path = dir.path().join("global/to-delete.md");
1493 std::fs::remove_file(&full_path).unwrap();
1494 let mut index = inner.index().unwrap();
1495 index
1496 .remove_path(std::path::Path::new("global/to-delete.md"))
1497 .unwrap();
1498 index.write().unwrap();
1499 let tree_oid = index.write_tree().unwrap();
1500 let tree = inner.find_tree(tree_oid).unwrap();
1501 let sig = git2::Signature::now("test", "test@test.com").unwrap();
1502 let parent = inner.head().unwrap().peel_to_commit().unwrap();
1503 let oid = inner
1504 .commit(Some("HEAD"), &sig, &sig, "delete file", &tree, &[&parent])
1505 .unwrap();
1506 let mut buf = [0u8; 20];
1507 buf.copy_from_slice(oid.as_bytes());
1508 buf
1509 };
1510
1511 let changes = repo.diff_changed_memories(first_oid, second_oid).unwrap();
1512 assert!(changes.upserted.is_empty());
1513 assert_eq!(changes.removed, vec!["global/to-delete".to_string()]);
1514 }
1515
1516 #[test]
1517 fn diff_changed_memories_ignores_non_md_files() {
1518 let dir = tempfile::tempdir().unwrap();
1519 let repo = open_repo(&dir, None);
1520
1521 let old_oid = {
1522 let inner = repo.inner.lock().unwrap();
1523 let mut buf = [0u8; 20];
1524 buf.copy_from_slice(
1525 inner
1526 .head()
1527 .unwrap()
1528 .peel_to_commit()
1529 .unwrap()
1530 .id()
1531 .as_bytes(),
1532 );
1533 buf
1534 };
1535
1536 let _ = commit_file(&repo, "global/config.json", "{}");
1538 let new_oid = commit_file(&repo, "other/note.md", "# ignored");
1539
1540 let changes = repo.diff_changed_memories(old_oid, new_oid).unwrap();
1541 assert!(
1542 changes.upserted.is_empty(),
1543 "should ignore non-.md and out-of-scope files"
1544 );
1545 assert!(changes.removed.is_empty());
1546 }
1547
1548 #[test]
1549 fn diff_changed_memories_detects_modified() {
1550 let dir = tempfile::tempdir().unwrap();
1551 let repo = open_repo(&dir, None);
1552
1553 let first_oid = commit_file(&repo, "projects/myproject/note.md", "version 1");
1554 let second_oid = commit_file(&repo, "projects/myproject/note.md", "version 2");
1555
1556 let changes = repo.diff_changed_memories(first_oid, second_oid).unwrap();
1557 assert_eq!(
1558 changes.upserted,
1559 vec!["projects/myproject/note".to_string()]
1560 );
1561 assert!(changes.removed.is_empty());
1562 }
1563
1564 #[test]
1567 fn diff_changed_memories_zero_oid_treats_all_as_added() {
1568 let dir = tempfile::tempdir().unwrap();
1569 let repo = open_repo(&dir, None);
1570
1571 let new_oid = commit_file(&repo, "global/first-memory.md", "# Hello");
1573
1574 let old_oid = [0u8; 20];
1576
1577 let changes = repo.diff_changed_memories(old_oid, new_oid).unwrap();
1578 assert_eq!(
1579 changes.upserted,
1580 vec!["global/first-memory".to_string()],
1581 "zero OID: all new-tree files should be additions"
1582 );
1583 assert!(changes.removed.is_empty(), "zero OID: no removals expected");
1584 }
1585}