1use chrono::{DateTime, Utc};
2use rmcp::schemars;
3use serde::{Deserialize, Serialize};
4use std::{fmt, str::FromStr};
5use uuid::Uuid;
6
7use crate::error::MemoryError;
8
9pub fn validate_name(name: &str) -> Result<(), MemoryError> {
19 if name.is_empty() {
20 return Err(MemoryError::InvalidInput {
21 reason: "name must not be empty".to_string(),
22 });
23 }
24
25 let components: Vec<&str> = name.split('/').collect();
26
27 if components.len() > 3 {
28 return Err(MemoryError::InvalidInput {
29 reason: format!("name '{}' exceeds maximum nesting depth of 3", name),
30 });
31 }
32
33 for component in &components {
34 if component.is_empty() {
35 return Err(MemoryError::InvalidInput {
36 reason: format!("name '{}' contains an empty path component", name),
37 });
38 }
39 if component.starts_with('.') {
40 return Err(MemoryError::InvalidInput {
41 reason: format!(
42 "name '{}' contains a dot-prefixed component '{}'",
43 name, component
44 ),
45 });
46 }
47 if !component
48 .chars()
49 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.')
50 {
51 return Err(MemoryError::InvalidInput {
52 reason: format!(
53 "name '{}' contains disallowed characters in component '{}'",
54 name, component
55 ),
56 });
57 }
58 }
59
60 Ok(())
61}
62
63pub fn validate_branch_name(branch: &str) -> Result<(), MemoryError> {
68 if branch.is_empty() {
69 return Err(MemoryError::InvalidInput {
70 reason: "branch name cannot be empty".into(),
71 });
72 }
73 if branch.contains("..") {
74 return Err(MemoryError::InvalidInput {
75 reason: "branch name cannot contain '..'".into(),
76 });
77 }
78 let invalid_chars = [' ', '~', '^', ':', '?', '*', '[', '\\'];
79 for c in branch.chars() {
80 if c.is_ascii_control() || invalid_chars.contains(&c) {
81 return Err(MemoryError::InvalidInput {
82 reason: format!("branch name contains invalid character '{}'", c),
83 });
84 }
85 }
86 if branch.starts_with('/')
87 || branch.ends_with('/')
88 || branch.ends_with('.')
89 || branch.starts_with('.')
90 {
91 return Err(MemoryError::InvalidInput {
92 reason: "branch name has invalid start/end character".into(),
93 });
94 }
95 if branch.contains("//") {
96 return Err(MemoryError::InvalidInput {
97 reason: "branch name contains consecutive slashes".into(),
98 });
99 }
100 Ok(())
101}
102
/// The scope a memory belongs to: the global scope or a single named project.
///
/// With serde this is adjacently tagged: the `type` key holds the variant
/// name, and the `name` key holds the project name for [`Scope::Project`].
/// The string form, used by `Display` and `FromStr`, is `global` or
/// `project:<name>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "type", content = "name")]
#[non_exhaustive]
pub enum Scope {
    /// Global scope. String form `global`; directory prefix `global`.
    Global,
    /// A single project's scope. String form `project:<name>`; directory
    /// prefix `projects/<name>`.
    Project(String),
}
120
121impl Scope {
122 pub fn dir_prefix(&self) -> String {
124 match self {
125 Scope::Global => "global".to_string(),
126 Scope::Project(name) => format!("projects/{}", name),
127 }
128 }
129}
130
131impl fmt::Display for Scope {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 match self {
134 Scope::Global => write!(f, "global"),
135 Scope::Project(name) => write!(f, "project:{}", name),
136 }
137 }
138}
139
140impl FromStr for Scope {
141 type Err = MemoryError;
142
143 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 if s == "global" {
148 return Ok(Scope::Global);
149 }
150 if let Some(name) = s.strip_prefix("project:") {
151 if name.is_empty() {
152 return Err(MemoryError::InvalidInput {
153 reason: "project scope requires a non-empty name after 'project:'".to_string(),
154 });
155 }
156 if name.contains('/') {
157 return Err(MemoryError::InvalidInput {
158 reason: "project name must not contain '/'".to_string(),
159 });
160 }
161 validate_name(name)?;
162 return Ok(Scope::Project(name.to_string()));
163 }
164 Err(MemoryError::InvalidInput {
165 reason: format!(
166 "unrecognised scope '{}'; expected 'global' or 'project:<name>'",
167 s
168 ),
169 })
170 }
171}
172
/// Metadata kept alongside a memory's content.
///
/// `Memory::to_markdown` writes these fields into the YAML frontmatter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMetadata {
    /// Tags attached to the memory.
    pub tags: Vec<String>,
    /// Scope the memory belongs to.
    pub scope: Scope,
    /// Creation time in UTC. `MemoryMetadata::new` sets it to `Utc::now()`.
    pub created_at: DateTime<Utc>,
    /// Last-update time in UTC. `MemoryMetadata::new` initialises it to the
    /// same instant as `created_at`.
    pub updated_at: DateTime<Utc>,
    /// Optional source label. It is omitted from the frontmatter when `None`.
    pub source: Option<String>,
}
191
192impl MemoryMetadata {
193 pub fn new(scope: Scope, tags: Vec<String>, source: Option<String>) -> Self {
195 let now = Utc::now();
196 Self {
197 tags,
198 scope,
199 created_at: now,
200 updated_at: now,
201 source,
202 }
203 }
204}
205
/// A stored memory: identifier, name, body text and metadata.
///
/// It is persisted as Markdown with YAML frontmatter via `Memory::to_markdown`
/// and read back with `Memory::from_markdown`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
    /// Unique identifier. `Memory::new` assigns a random UUID v4 string.
    pub id: String,
    /// Memory name. Presumably validated with `validate_name` by callers — confirm.
    pub name: String,
    /// Body text, written after the frontmatter.
    pub content: String,
    /// Tags, scope, timestamps and optional source.
    pub metadata: MemoryMetadata,
}
222
223impl Memory {
224 pub fn new(name: String, content: String, metadata: MemoryMetadata) -> Self {
226 Self {
227 id: Uuid::new_v4().to_string(),
228 name,
229 content,
230 metadata,
231 }
232 }
233
234 pub fn to_markdown(&self) -> Result<String, MemoryError> {
245 #[derive(Serialize)]
246 struct Frontmatter<'a> {
247 id: &'a str,
248 name: &'a str,
249 tags: &'a [String],
250 scope: &'a Scope,
251 created_at: &'a DateTime<Utc>,
252 updated_at: &'a DateTime<Utc>,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 source: Option<&'a str>,
255 }
256
257 let fm = Frontmatter {
258 id: &self.id,
259 name: &self.name,
260 tags: &self.metadata.tags,
261 scope: &self.metadata.scope,
262 created_at: &self.metadata.created_at,
263 updated_at: &self.metadata.updated_at,
264 source: self.metadata.source.as_deref(),
265 };
266
267 let yaml = serde_yaml_ng::to_string(&fm)?;
268 Ok(format!("---\n{}---\n\n{}", yaml, self.content))
269 }
270
271 pub fn from_markdown(raw: &str) -> Result<Self, MemoryError> {
273 let rest = raw
275 .strip_prefix("---\n")
276 .ok_or_else(|| MemoryError::InvalidInput {
277 reason: "missing opening frontmatter delimiter".to_string(),
278 })?;
279
280 let end_marker = rest
282 .find("\n---\n")
283 .ok_or_else(|| MemoryError::InvalidInput {
284 reason: "missing closing frontmatter delimiter".to_string(),
285 })?;
286
287 let yaml_str = &rest[..end_marker];
288 let body = rest[end_marker + 5..].trim_start_matches('\n');
290
291 #[derive(Deserialize)]
292 struct Frontmatter {
293 id: String,
294 name: String,
295 tags: Vec<String>,
296 scope: Scope,
297 created_at: DateTime<Utc>,
298 updated_at: DateTime<Utc>,
299 source: Option<String>,
300 }
301
302 let fm: Frontmatter = serde_yaml_ng::from_str(yaml_str)?;
303
304 Ok(Memory {
305 id: fm.id,
306 name: fm.name,
307 content: body.to_string(),
308 metadata: MemoryMetadata {
309 tags: fm.tags,
310 scope: fm.scope,
311 created_at: fm.created_at,
312 updated_at: fm.updated_at,
313 source: fm.source,
314 },
315 })
316 }
317}
318
/// Which scopes a query covers. Produced by `parse_scope_filter`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScopeFilter {
    /// Only the global scope. Parsed from `None` or `"global"`.
    GlobalOnly,
    /// The named project plus the global scope. Parsed from `"project:<name>"`.
    ProjectAndGlobal(String),
    /// Every scope. Parsed from `"all"`.
    All,
}
336
337pub fn parse_scope_filter(scope: Option<&str>) -> Result<ScopeFilter, MemoryError> {
346 match scope {
347 None | Some("global") => Ok(ScopeFilter::GlobalOnly),
348 Some("all") => Ok(ScopeFilter::All),
349 Some(s) => {
350 let parsed = s.parse::<Scope>()?;
351 match parsed {
352 Scope::Project(name) => Ok(ScopeFilter::ProjectAndGlobal(name)),
353 Scope::Global => Ok(ScopeFilter::GlobalOnly),
356 }
357 }
358 }
359}
360
361pub fn parse_scope(scope: Option<&str>) -> Result<Scope, MemoryError> {
367 match scope {
368 None => Ok(Scope::Global),
369 Some(s) => s.parse::<Scope>(),
370 }
371}
372
373pub fn parse_qualified_name(qualified: &str) -> Result<(Scope, String), MemoryError> {
376 if let Some(rest) = qualified.strip_prefix("global/") {
377 validate_name(rest)?;
378 return Ok((Scope::Global, rest.to_string()));
379 }
380 if let Some(rest) = qualified.strip_prefix("projects/") {
381 if let Some(slash_pos) = rest.find('/') {
383 let project = &rest[..slash_pos];
384 let name = &rest[slash_pos + 1..];
385 if project.is_empty() || name.is_empty() {
386 return Err(MemoryError::InvalidInput {
387 reason: format!(
388 "malformed qualified name '{}': project or memory name is empty",
389 qualified
390 ),
391 });
392 }
393 validate_name(project)?;
394 validate_name(name)?;
395 return Ok((Scope::Project(project.to_string()), name.to_string()));
396 }
397 return Err(MemoryError::InvalidInput {
398 reason: format!(
399 "malformed qualified name '{}': missing memory name after project",
400 qualified
401 ),
402 });
403 }
404 Err(MemoryError::InvalidInput {
405 reason: format!(
406 "malformed qualified name '{}': must start with 'global/' or 'projects/'",
407 qualified
408 ),
409 })
410}
411
// Arguments for storing a new memory. Presumably this is the input of the
// `remember` tool — confirm.
// NOTE: plain `//` comments are used deliberately. schemars turns `///` doc
// comments into JSON-schema `description`s, so adding them would change the
// published schema.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RememberArgs {
    // Body text (required). Presumably becomes `Memory::content`.
    pub content: String,
    // Memory name (required). Presumably checked with `validate_name` — confirm.
    pub name: String,
    // Tags. A missing field deserializes to an empty list.
    #[serde(default)]
    pub tags: Vec<String>,
    // Optional scope string; missing means `None`. Presumably parsed with
    // `parse_scope` ("global" / "project:<name>") — confirm in the handler.
    #[serde(default)]
    pub scope: Option<String>,
    // Optional source label; missing means `None`.
    #[serde(default)]
    pub source: Option<String>,
}
433
// Arguments for searching memories. Presumably this is the input of the
// `recall` tool — confirm. Plain `//` comments keep the schemars schema unchanged.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RecallArgs {
    // Search query text (required).
    pub query: String,
    // Optional scope string; missing means `None`. Presumably parsed with
    // `parse_scope_filter` (which also accepts "all") — confirm.
    #[serde(default)]
    pub scope: Option<String>,
    // Optional cap on the number of results; missing means `None`.
    // The default is decided by the handler.
    #[serde(default)]
    pub limit: Option<usize>,
}
446
// Arguments for deleting a memory. Presumably this is the input of the
// `forget` tool — confirm. Plain `//` comments keep the schemars schema unchanged.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ForgetArgs {
    // Name of the memory to delete (required).
    pub name: String,
    // Optional scope string; missing means `None`. Presumably parsed with
    // `parse_scope` — confirm.
    #[serde(default)]
    pub scope: Option<String>,
}
456
// Arguments for modifying an existing memory. Presumably this is the input of
// the `edit` tool — confirm. Plain `//` comments keep the schemars schema unchanged.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct EditArgs {
    // Name of the memory to edit (required).
    pub name: String,
    // New body text. `None` presumably means "leave unchanged" — confirm.
    #[serde(default)]
    pub content: Option<String>,
    // Replacement tag list. `None` presumably means "leave unchanged" — confirm.
    #[serde(default)]
    pub tags: Option<Vec<String>>,
    // Optional scope string locating the memory; missing means `None`.
    #[serde(default)]
    pub scope: Option<String>,
}
472
// Arguments for listing memories. Plain `//` comments keep the schemars schema unchanged.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ListArgs {
    // Optional scope string; missing means `None`. Presumably parsed with
    // `parse_scope_filter` — confirm.
    #[serde(default)]
    pub scope: Option<String>,
}
480
// Arguments for reading a single memory. Plain `//` comments keep the schemars
// schema unchanged.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ReadArgs {
    // Name of the memory to read (required).
    pub name: String,
    // Optional scope string; missing means `None`. Presumably parsed with
    // `parse_scope` — confirm.
    #[serde(default)]
    pub scope: Option<String>,
}
490
// Arguments for syncing with the remote. Plain `//` comments keep the schemars
// schema unchanged.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct SyncArgs {
    // Whether to pull before pushing; missing means `None`. The default applied
    // for `None` is decided by the handler — confirm.
    #[serde(default)]
    pub pull_first: Option<bool>,
}
498
/// Outcome of pulling from the remote.
#[derive(Debug)]
#[non_exhaustive]
pub enum PullResult {
    /// Nothing was pulled because no remote is available. Presumably this means
    /// none is configured — confirm in the repo code.
    NoRemote,
    /// The local branch already matched the remote.
    UpToDate,
    /// The local branch was fast-forwarded to the remote head.
    FastForward {
        /// Head commit id before the pull. These are 20 raw bytes, presumably a
        /// SHA-1 object id — confirm.
        old_head: [u8; 20],
        /// Head commit id after the pull.
        new_head: [u8; 20],
    },
    /// The local and remote histories were merged.
    Merged {
        /// Number of conflicts resolved during the merge. The resolution
        /// strategy is not visible here.
        conflicts_resolved: usize,
        /// Head commit id before the pull.
        old_head: [u8; 20],
        /// Head commit id after the merge.
        new_head: [u8; 20],
    },
}
528
/// Memories that changed, split into upserted and removed entries.
///
/// NOTE(review): the exact form of each entry (qualified name or file path) is
/// not visible here — confirm against the producer.
#[derive(Debug, Default)]
pub struct ChangedMemories {
    /// Entries that were added or modified.
    pub upserted: Vec<String>,
    /// Entries that were deleted.
    pub removed: Vec<String>,
}
541
542impl ChangedMemories {
543 pub fn is_empty(&self) -> bool {
545 self.upserted.is_empty() && self.removed.is_empty()
546 }
547}
548
/// Counters collected during a reindex pass.
///
/// NOTE(review): presumably this means rebuilding the vector index from the
/// repository — confirm.
#[derive(Debug, Default)]
pub struct ReindexStats {
    /// Number of entries added.
    pub added: usize,
    /// Number of entries updated.
    pub updated: usize,
    /// Number of entries removed.
    pub removed: usize,
    /// Number of errors encountered.
    pub errors: usize,
}
565
566use std::sync::Arc;
571
572use crate::{
573 auth::AuthProvider, embedding::EmbeddingBackend, health::HealthRegistry, index::VectorStore,
574 repo::MemoryRepo,
575};
576
/// Shared application state. It bundles the memory repository, embedding
/// backend, vector index, auth provider, configured branch and health registry.
///
/// `#[non_exhaustive]` prevents struct-literal construction outside this crate;
/// use [`AppState::new`].
#[non_exhaustive]
pub struct AppState {
    /// Memory repository, shared via `Arc`.
    pub repo: Arc<MemoryRepo>,
    /// Embedding backend (boxed trait object).
    pub embedding: Box<dyn EmbeddingBackend>,
    /// Vector store (boxed trait object).
    pub index: Box<dyn VectorStore>,
    /// Authentication provider.
    pub auth: AuthProvider,
    /// Branch name. Presumably validated with `validate_branch_name` before
    /// reaching here — confirm.
    pub branch: String,
    /// Health registry.
    pub health: HealthRegistry,
}
596
597impl AppState {
598 pub fn new(
600 repo: Arc<MemoryRepo>,
601 branch: String,
602 embedding: Box<dyn EmbeddingBackend>,
603 index: Box<dyn VectorStore>,
604 auth: AuthProvider,
605 health: HealthRegistry,
606 ) -> Self {
607 Self {
608 repo,
609 embedding,
610 index,
611 auth,
612 branch,
613 health,
614 }
615 }
616}
617
#[cfg(test)]
mod tests {
    //! Unit tests covering:
    //! - name and branch validation;
    //! - scope parsing and formatting;
    //! - qualified-name parsing;
    //! - the Markdown round-trip of [`Memory`].

    use super::*;

    /// Builds a fully populated, project-scoped memory with a fixed id and timestamps.
    fn make_memory() -> Memory {
        let meta = MemoryMetadata {
            tags: vec!["test".to_string(), "round-trip".to_string()],
            scope: Scope::Project("my-project".to_string()),
            created_at: DateTime::from_timestamp(1_700_000_000, 0).unwrap(),
            updated_at: DateTime::from_timestamp(1_700_000_100, 0).unwrap(),
            source: Some("unit-test".to_string()),
        };
        Memory {
            id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
            name: "test-memory".to_string(),
            content: "# Hello\n\nThis is a test memory.".to_string(),
            metadata: meta,
        }
    }

    // ---- Markdown round-trip ----

    #[test]
    fn round_trip_markdown() {
        let original = make_memory();
        let rendered = original.to_markdown().expect("to_markdown should not fail");
        let parsed = Memory::from_markdown(&rendered).expect("from_markdown should not fail");

        assert_eq!(original.id, parsed.id);
        assert_eq!(original.name, parsed.name);
        assert_eq!(original.content, parsed.content);
        assert_eq!(original.metadata.tags, parsed.metadata.tags);
        assert_eq!(original.metadata.scope, parsed.metadata.scope);
        assert_eq!(
            original.metadata.created_at.timestamp(),
            parsed.metadata.created_at.timestamp()
        );
        assert_eq!(
            original.metadata.updated_at.timestamp(),
            parsed.metadata.updated_at.timestamp()
        );
        assert_eq!(original.metadata.source, parsed.metadata.source);
    }

    #[test]
    fn round_trip_global_scope() {
        let meta = MemoryMetadata::new(Scope::Global, vec!["global-tag".to_string()], None);
        let mem = Memory::new("global-mem".to_string(), "Some content.".to_string(), meta);
        let rendered = mem.to_markdown().unwrap();
        let parsed = Memory::from_markdown(&rendered).unwrap();

        assert_eq!(parsed.metadata.scope, Scope::Global);
        assert_eq!(parsed.metadata.source, None);
        assert_eq!(parsed.content, "Some content.");
    }

    #[test]
    fn round_trip_no_source() {
        let meta = MemoryMetadata::new(Scope::Project("proj".to_string()), vec![], None);
        let mem = Memory::new("no-src".to_string(), "Body.".to_string(), meta);
        let md = mem.to_markdown().unwrap();
        // `source` is skipped entirely when it is `None`.
        assert!(!md.contains("source:"));
        let parsed = Memory::from_markdown(&md).unwrap();
        assert_eq!(parsed.metadata.source, None);
    }

    #[test]
    fn from_markdown_missing_frontmatter_fails() {
        let result = Memory::from_markdown("just plain text");
        assert!(result.is_err());
    }

    #[test]
    fn from_markdown_missing_closing_delimiter_fails() {
        assert!(Memory::from_markdown("---\nid: x\nname: y\n").is_err());
    }

    // ---- Scope ----

    #[test]
    fn scope_dir_prefix() {
        assert_eq!(Scope::Global.dir_prefix(), "global");
        assert_eq!(
            Scope::Project("foo".to_string()).dir_prefix(),
            "projects/foo"
        );
    }

    #[test]
    fn scope_from_str_global() {
        assert_eq!("global".parse::<Scope>().unwrap(), Scope::Global);
    }

    #[test]
    fn scope_from_str_project() {
        assert_eq!(
            "project:my-proj".parse::<Scope>().unwrap(),
            Scope::Project("my-proj".to_string())
        );
    }

    #[test]
    fn scope_from_str_empty_project_name_fails() {
        assert!("project:".parse::<Scope>().is_err());
    }

    #[test]
    fn scope_from_str_unknown_fails() {
        assert!("unknown".parse::<Scope>().is_err());
        // Matching is case-sensitive.
        assert!("PROJECT:foo".parse::<Scope>().is_err());
    }

    #[test]
    fn scope_from_str_project_traversal_fails() {
        assert!("project:../../etc".parse::<Scope>().is_err());
    }

    #[test]
    fn scope_display_round_trips_through_from_str() {
        for scope in [Scope::Global, Scope::Project("my-proj".to_string())] {
            let parsed: Scope = scope.to_string().parse().unwrap();
            assert_eq!(parsed, scope);
        }
    }

    // ---- validate_name ----

    #[test]
    fn validate_name_accepts_valid() {
        assert!(validate_name("my-memory").is_ok());
        assert!(validate_name("some_memory").is_ok());
        assert!(validate_name("nested/path").is_ok());
        assert!(validate_name("v1.2.3").is_ok());
    }

    #[test]
    fn validate_name_rejects_traversal() {
        assert!(validate_name("../../etc/passwd").is_err());
        assert!(validate_name("..").is_err());
        assert!(validate_name(".hidden").is_err());
        assert!(validate_name("a/../b").is_err());
    }

    #[test]
    fn validate_name_rejects_empty() {
        assert!(validate_name("").is_err());
    }

    #[test]
    fn validate_name_rejects_special_chars() {
        assert!(validate_name("foo;bar").is_err());
        assert!(validate_name("foo bar").is_err());
        assert!(validate_name("foo\0bar").is_err());
    }

    #[test]
    fn validate_name_rejects_empty_component() {
        assert!(validate_name("foo//bar").is_err());
        assert!(validate_name("/absolute").is_err());
    }

    #[test]
    fn validate_name_enforces_max_depth() {
        assert!(validate_name("a/b/c").is_ok());
        assert!(validate_name("a/b/c/d").is_err());
    }

    // ---- parse_scope ----

    #[test]
    fn test_parse_scope_none_defaults_global() {
        assert_eq!(parse_scope(None).unwrap(), Scope::Global);
    }

    #[test]
    fn test_parse_scope_some_global() {
        assert_eq!(parse_scope(Some("global")).unwrap(), Scope::Global);
    }

    #[test]
    fn test_parse_scope_some_project() {
        assert_eq!(
            parse_scope(Some("project:my-proj")).unwrap(),
            Scope::Project("my-proj".to_string())
        );
    }

    #[test]
    fn test_parse_scope_rejects_invalid() {
        assert!(parse_scope(Some("bogus")).is_err());
        assert!(parse_scope(Some("project:a/b")).is_err());
    }

    // ---- parse_qualified_name ----

    #[test]
    fn test_parse_qualified_name_global() {
        let (scope, name) = parse_qualified_name("global/my-memory").unwrap();
        assert_eq!(scope, Scope::Global);
        assert_eq!(name, "my-memory");
    }

    #[test]
    fn test_parse_qualified_name_project() {
        let (scope, name) = parse_qualified_name("projects/my-project/my-memory").unwrap();
        assert_eq!(scope, Scope::Project("my-project".to_string()));
        assert_eq!(name, "my-memory");
    }

    #[test]
    fn test_parse_qualified_name_nested() {
        let (scope, name) = parse_qualified_name("projects/my-project/nested/memory").unwrap();
        assert_eq!(scope, Scope::Project("my-project".to_string()));
        assert_eq!(name, "nested/memory");
    }

    #[test]
    fn test_parse_qualified_name_rejects_malformed() {
        assert!(parse_qualified_name("other/my-memory").is_err());
        assert!(parse_qualified_name("projects/only-project").is_err());
        assert!(parse_qualified_name("projects//my-memory").is_err());
        assert!(parse_qualified_name("projects/p/").is_err());
        assert!(parse_qualified_name("global/../secret").is_err());
    }

    // ---- validate_branch_name ----

    #[test]
    fn validate_branch_name_accepts_valid() {
        assert!(validate_branch_name("main").is_ok());
        assert!(validate_branch_name("feature/foo").is_ok());
        assert!(validate_branch_name("release-1.0").is_ok());
        assert!(validate_branch_name("a/b/c").is_ok());
        assert!(validate_branch_name("my-branch_v2").is_ok());
    }

    #[test]
    fn validate_branch_name_rejects_empty() {
        assert!(validate_branch_name("").is_err());
    }

    #[test]
    fn validate_branch_name_rejects_dot_dot() {
        assert!(validate_branch_name("foo..bar").is_err());
        assert!(validate_branch_name("..").is_err());
    }

    #[test]
    fn validate_branch_name_rejects_invalid_chars() {
        for name in &[
            "foo bar", "foo~bar", "foo^bar", "foo:bar", "foo?bar", "foo*bar", "foo[bar", "foo\\bar",
        ] {
            assert!(
                validate_branch_name(name).is_err(),
                "should reject: {}",
                name
            );
        }
    }

    #[test]
    fn validate_branch_name_rejects_invalid_start_end() {
        assert!(validate_branch_name("/foo").is_err());
        assert!(validate_branch_name("foo/").is_err());
        assert!(validate_branch_name(".foo").is_err());
        assert!(validate_branch_name("foo.").is_err());
    }

    #[test]
    fn validate_branch_name_rejects_consecutive_slashes() {
        assert!(validate_branch_name("foo//bar").is_err());
    }

    // ---- parse_scope_filter ----

    #[test]
    fn scope_filter_none_defaults_to_global_only() {
        assert_eq!(parse_scope_filter(None).unwrap(), ScopeFilter::GlobalOnly);
    }

    #[test]
    fn scope_filter_global_returns_global_only() {
        assert_eq!(
            parse_scope_filter(Some("global")).unwrap(),
            ScopeFilter::GlobalOnly
        );
    }

    #[test]
    fn scope_filter_project_returns_project_and_global() {
        assert_eq!(
            parse_scope_filter(Some("project:my-proj")).unwrap(),
            ScopeFilter::ProjectAndGlobal("my-proj".to_string()),
        );
    }

    #[test]
    fn scope_filter_all_returns_all() {
        assert_eq!(parse_scope_filter(Some("all")).unwrap(), ScopeFilter::All);
    }

    #[test]
    fn scope_filter_invalid_returns_error() {
        assert!(parse_scope_filter(Some("bogus")).is_err());
    }

    #[test]
    fn scope_filter_empty_project_returns_error() {
        assert!(parse_scope_filter(Some("project:")).is_err());
    }

    // ---- ChangedMemories ----

    #[test]
    fn changed_memories_is_empty() {
        let mut changes = ChangedMemories::default();
        assert!(changes.is_empty());
        changes.removed.push("global/x".to_string());
        assert!(!changes.is_empty());
    }
}