1use chrono::{DateTime, Utc};
2use rmcp::schemars;
3use serde::{Deserialize, Serialize};
4use std::{fmt, str::FromStr};
5use uuid::Uuid;
6
7use crate::error::MemoryError;
8
9pub fn validate_name(name: &str) -> Result<(), MemoryError> {
19 if name.is_empty() {
20 return Err(MemoryError::InvalidInput {
21 reason: "name must not be empty".to_string(),
22 });
23 }
24
25 let components: Vec<&str> = name.split('/').collect();
26
27 if components.len() > 3 {
28 return Err(MemoryError::InvalidInput {
29 reason: format!("name '{}' exceeds maximum nesting depth of 3", name),
30 });
31 }
32
33 for component in &components {
34 if component.is_empty() {
35 return Err(MemoryError::InvalidInput {
36 reason: format!("name '{}' contains an empty path component", name),
37 });
38 }
39 if component.starts_with('.') {
40 return Err(MemoryError::InvalidInput {
41 reason: format!(
42 "name '{}' contains a dot-prefixed component '{}'",
43 name, component
44 ),
45 });
46 }
47 if !component
48 .chars()
49 .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.')
50 {
51 return Err(MemoryError::InvalidInput {
52 reason: format!(
53 "name '{}' contains disallowed characters in component '{}'",
54 name, component
55 ),
56 });
57 }
58 }
59
60 Ok(())
61}
62
63pub fn validate_branch_name(branch: &str) -> Result<(), MemoryError> {
68 if branch.is_empty() {
69 return Err(MemoryError::InvalidInput {
70 reason: "branch name cannot be empty".into(),
71 });
72 }
73 if branch.contains("..") {
74 return Err(MemoryError::InvalidInput {
75 reason: "branch name cannot contain '..'".into(),
76 });
77 }
78 let invalid_chars = [' ', '~', '^', ':', '?', '*', '[', '\\'];
79 for c in branch.chars() {
80 if c.is_ascii_control() || invalid_chars.contains(&c) {
81 return Err(MemoryError::InvalidInput {
82 reason: format!("branch name contains invalid character '{}'", c),
83 });
84 }
85 }
86 if branch.starts_with('/')
87 || branch.ends_with('/')
88 || branch.ends_with('.')
89 || branch.starts_with('.')
90 {
91 return Err(MemoryError::InvalidInput {
92 reason: "branch name has invalid start/end character".into(),
93 });
94 }
95 if branch.contains("//") {
96 return Err(MemoryError::InvalidInput {
97 reason: "branch name contains consecutive slashes".into(),
98 });
99 }
100 Ok(())
101}
102
/// The scope a memory belongs to: the global area or a named project.
///
/// Serialized adjacently tagged (`#[serde(tag = "type", content = "name")]`):
/// the variant name is written under `type` and, for `Project`, the project
/// name under `name`. Changing these attributes changes the stored format.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "type", content = "name")]
#[non_exhaustive]
pub enum Scope {
    /// The global scope; its directory prefix is `global` and its text form is `global`.
    Global,
    /// A project scope; its directory prefix is `projects/<name>` and its text
    /// form is `project:<name>`.
    Project(String),
}
120
121impl Scope {
122 pub fn dir_prefix(&self) -> String {
124 match self {
125 Scope::Global => "global".to_string(),
126 Scope::Project(name) => format!("projects/{}", name),
127 }
128 }
129}
130
131impl fmt::Display for Scope {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 match self {
134 Scope::Global => write!(f, "global"),
135 Scope::Project(name) => write!(f, "project:{}", name),
136 }
137 }
138}
139
140impl FromStr for Scope {
141 type Err = MemoryError;
142
143 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 if s == "global" {
148 return Ok(Scope::Global);
149 }
150 if let Some(name) = s.strip_prefix("project:") {
151 if name.is_empty() {
152 return Err(MemoryError::InvalidInput {
153 reason: "project scope requires a non-empty name after 'project:'".to_string(),
154 });
155 }
156 if name.contains('/') {
157 return Err(MemoryError::InvalidInput {
158 reason: "project name must not contain '/'".to_string(),
159 });
160 }
161 validate_name(name)?;
162 return Ok(Scope::Project(name.to_string()));
163 }
164 Err(MemoryError::InvalidInput {
165 reason: format!(
166 "unrecognised scope '{}'; expected 'global' or 'project:<name>'",
167 s
168 ),
169 })
170 }
171}
172
/// Metadata stored alongside a memory's content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMetadata {
    /// Tags attached to the memory (not validated by this type).
    pub tags: Vec<String>,
    /// Scope (global or a project) the memory belongs to.
    pub scope: Scope,
    /// Creation time in UTC; [`MemoryMetadata::new`] sets it to the current time.
    pub created_at: DateTime<Utc>,
    /// Last-update time in UTC; [`MemoryMetadata::new`] sets it equal to `created_at`.
    pub updated_at: DateTime<Utc>,
    /// Optional source string; presumably records where the memory came from — confirm.
    /// [`Memory::to_markdown`] omits it from the frontmatter when `None`.
    pub source: Option<String>,
}
191
192impl MemoryMetadata {
193 pub fn new(scope: Scope, tags: Vec<String>, source: Option<String>) -> Self {
195 let now = Utc::now();
196 Self {
197 tags,
198 scope,
199 created_at: now,
200 updated_at: now,
201 source,
202 }
203 }
204}
205
/// A single memory: identifier, name, body text and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
    /// Identifier; [`Memory::new`] assigns a random UUID v4 string.
    pub id: String,
    /// Memory name. NOTE(review): this type does not enforce [`validate_name`];
    /// presumably callers validate before constructing — confirm.
    pub name: String,
    /// Body text, written after the YAML frontmatter by [`Memory::to_markdown`].
    pub content: String,
    /// Tags, scope, timestamps and optional source.
    pub metadata: MemoryMetadata,
}
222
223impl Memory {
224 pub fn new(name: String, content: String, metadata: MemoryMetadata) -> Self {
226 Self {
227 id: Uuid::new_v4().to_string(),
228 name,
229 content,
230 metadata,
231 }
232 }
233
234 pub fn to_markdown(&self) -> Result<String, MemoryError> {
245 #[derive(Serialize)]
246 struct Frontmatter<'a> {
247 id: &'a str,
248 name: &'a str,
249 tags: &'a [String],
250 scope: &'a Scope,
251 created_at: &'a DateTime<Utc>,
252 updated_at: &'a DateTime<Utc>,
253 #[serde(skip_serializing_if = "Option::is_none")]
254 source: Option<&'a str>,
255 }
256
257 let fm = Frontmatter {
258 id: &self.id,
259 name: &self.name,
260 tags: &self.metadata.tags,
261 scope: &self.metadata.scope,
262 created_at: &self.metadata.created_at,
263 updated_at: &self.metadata.updated_at,
264 source: self.metadata.source.as_deref(),
265 };
266
267 let yaml = serde_yaml::to_string(&fm)?;
268 Ok(format!("---\n{}---\n\n{}", yaml, self.content))
269 }
270
271 pub fn from_markdown(raw: &str) -> Result<Self, MemoryError> {
273 let rest = raw
275 .strip_prefix("---\n")
276 .ok_or_else(|| MemoryError::InvalidInput {
277 reason: "missing opening frontmatter delimiter".to_string(),
278 })?;
279
280 let end_marker = rest
282 .find("\n---\n")
283 .ok_or_else(|| MemoryError::InvalidInput {
284 reason: "missing closing frontmatter delimiter".to_string(),
285 })?;
286
287 let yaml_str = &rest[..end_marker];
288 let body = rest[end_marker + 5..].trim_start_matches('\n');
290
291 #[derive(Deserialize)]
292 struct Frontmatter {
293 id: String,
294 name: String,
295 tags: Vec<String>,
296 scope: Scope,
297 created_at: DateTime<Utc>,
298 updated_at: DateTime<Utc>,
299 source: Option<String>,
300 }
301
302 let fm: Frontmatter = serde_yaml::from_str(yaml_str)?;
303
304 Ok(Memory {
305 id: fm.id,
306 name: fm.name,
307 content: body.to_string(),
308 metadata: MemoryMetadata {
309 tags: fm.tags,
310 scope: fm.scope,
311 created_at: fm.created_at,
312 updated_at: fm.updated_at,
313 source: fm.source,
314 },
315 })
316 }
317}
318
/// Which scopes a query should cover; produced by [`parse_scope_filter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScopeFilter {
    /// Only the global scope (chosen for `None`, `"global"`).
    GlobalOnly,
    /// The named project together with the global scope (chosen for `"project:<name>"`).
    ProjectAndGlobal(String),
    /// Chosen for the string `"all"`; presumably every scope — confirm where the filter is applied.
    All,
}
336
337pub fn parse_scope_filter(scope: Option<&str>) -> Result<ScopeFilter, MemoryError> {
346 match scope {
347 None | Some("global") => Ok(ScopeFilter::GlobalOnly),
348 Some("all") => Ok(ScopeFilter::All),
349 Some(s) => {
350 let parsed = s.parse::<Scope>()?;
351 match parsed {
352 Scope::Project(name) => Ok(ScopeFilter::ProjectAndGlobal(name)),
353 Scope::Global => Ok(ScopeFilter::GlobalOnly),
356 }
357 }
358 }
359}
360
361pub fn parse_scope(scope: Option<&str>) -> Result<Scope, MemoryError> {
367 match scope {
368 None => Ok(Scope::Global),
369 Some(s) => s.parse::<Scope>(),
370 }
371}
372
373pub fn parse_qualified_name(qualified: &str) -> Result<(Scope, String), MemoryError> {
376 if let Some(rest) = qualified.strip_prefix("global/") {
377 validate_name(rest)?;
378 return Ok((Scope::Global, rest.to_string()));
379 }
380 if let Some(rest) = qualified.strip_prefix("projects/") {
381 if let Some(slash_pos) = rest.find('/') {
383 let project = &rest[..slash_pos];
384 let name = &rest[slash_pos + 1..];
385 if project.is_empty() || name.is_empty() {
386 return Err(MemoryError::InvalidInput {
387 reason: format!(
388 "malformed qualified name '{}': project or memory name is empty",
389 qualified
390 ),
391 });
392 }
393 validate_name(project)?;
394 validate_name(name)?;
395 return Ok((Scope::Project(project.to_string()), name.to_string()));
396 }
397 return Err(MemoryError::InvalidInput {
398 reason: format!(
399 "malformed qualified name '{}': missing memory name after project",
400 qualified
401 ),
402 });
403 }
404 Err(MemoryError::InvalidInput {
405 reason: format!(
406 "malformed qualified name '{}': must start with 'global/' or 'projects/'",
407 qualified
408 ),
409 })
410}
411
// Arguments for the `remember` tool (presumably the MCP tool that stores a memory — confirm).
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RememberArgs {
    // Body text of the memory.
    pub content: String,
    // Memory name; presumably checked with `validate_name` by the handler — confirm.
    pub name: String,
    // Tags; an empty list when omitted.
    #[serde(default)]
    pub tags: Vec<String>,
    // Scope string such as "global" or "project:<name>"; `None` when omitted.
    #[serde(default)]
    pub scope: Option<String>,
    // Optional source string; `None` when omitted.
    #[serde(default)]
    pub source: Option<String>,
}
433
// Arguments for the `recall` tool (presumably a search over stored memories — confirm).
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RecallArgs {
    // Query text.
    pub query: String,
    // Scope filter string; `None` when omitted. Presumably parsed with
    // `parse_scope_filter` ("global", "all", "project:<name>") — confirm.
    #[serde(default)]
    pub scope: Option<String>,
    // Optional result limit; `None` when omitted (any default is applied elsewhere).
    #[serde(default)]
    pub limit: Option<usize>,
}
446
// Arguments for the `forget` tool (presumably deletes a memory — confirm).
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ForgetArgs {
    // Name of the memory to act on.
    pub name: String,
    // Scope string such as "global" or "project:<name>"; `None` when omitted.
    #[serde(default)]
    pub scope: Option<String>,
}
456
// Arguments for the `edit` tool.
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct EditArgs {
    // Name of the memory to edit.
    pub name: String,
    // Replacement content; `None` when omitted (presumably leaves content unchanged — confirm).
    #[serde(default)]
    pub content: Option<String>,
    // Replacement tags; `None` when omitted (presumably leaves tags unchanged — confirm).
    #[serde(default)]
    pub tags: Option<Vec<String>>,
    // Scope string such as "global" or "project:<name>"; `None` when omitted.
    #[serde(default)]
    pub scope: Option<String>,
}
472
// Arguments for the `list` tool.
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ListArgs {
    // Scope filter string; `None` when omitted. Presumably parsed with
    // `parse_scope_filter` — confirm.
    #[serde(default)]
    pub scope: Option<String>,
}
480
// Arguments for the `read` tool.
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ReadArgs {
    // Name of the memory to read.
    pub name: String,
    // Scope string such as "global" or "project:<name>"; `None` when omitted.
    #[serde(default)]
    pub scope: Option<String>,
}
490
// Arguments for the `sync` tool.
// Plain `//` comments on purpose: `///` would become schemars schema descriptions.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct SyncArgs {
    // Whether to pull before syncing; `None` when omitted (default decided by the handler — confirm).
    #[serde(default)]
    pub pull_first: Option<bool>,
}
498
/// Outcome of a pull from the remote.
///
/// Variant meanings follow their names; the producing code is not visible here —
/// confirm against the pull implementation.
#[derive(Debug)]
#[non_exhaustive]
pub enum PullResult {
    /// Presumably: no remote is configured, so nothing was pulled.
    NoRemote,
    /// Presumably: local history already matched the remote.
    UpToDate,
    /// Presumably: the local head was fast-forwarded to the remote head.
    FastForward {
        /// Head before the pull (20 bytes — presumably a raw SHA-1 object id; confirm).
        old_head: [u8; 20],
        /// Head after the pull (same encoding as `old_head`).
        new_head: [u8; 20],
    },
    /// Presumably: histories diverged and were merged.
    Merged {
        /// Number of conflicts resolved during the merge.
        conflicts_resolved: usize,
        /// Head before the pull (20 bytes — presumably a raw SHA-1 object id; confirm).
        old_head: [u8; 20],
        /// Head after the merge (same encoding as `old_head`).
        new_head: [u8; 20],
    },
}
528
/// Memories that changed, split into upserted and removed entries.
///
/// NOTE(review): the `String`s are presumably memory names or paths — confirm
/// with the code that fills this in.
#[derive(Debug, Default)]
pub struct ChangedMemories {
    /// Entries that were added or modified.
    pub upserted: Vec<String>,
    /// Entries that were deleted.
    pub removed: Vec<String>,
}
541
542impl ChangedMemories {
543 pub fn is_empty(&self) -> bool {
545 self.upserted.is_empty() && self.removed.is_empty()
546 }
547}
548
/// Counters reported by a reindex pass; all start at zero via `Default`.
///
/// NOTE(review): presumably each counts memories/index entries — confirm with the reindexer.
#[derive(Debug, Default)]
pub struct ReindexStats {
    /// Entries newly added.
    pub added: usize,
    /// Entries updated in place.
    pub updated: usize,
    /// Entries removed.
    pub removed: usize,
    /// Entries that failed to process.
    pub errors: usize,
}
565
566use std::sync::Arc;
571
572use crate::{
573 auth::AuthProvider, embedding::EmbeddingBackend, index::ScopedIndex, repo::MemoryRepo,
574};
575
/// Components bundled together for the server (presumably shared by the request handlers — confirm).
///
/// `#[non_exhaustive]` keeps other crates from building it with a struct literal;
/// use [`AppState::new`].
#[non_exhaustive]
pub struct AppState {
    /// Memory repository, shared via `Arc`.
    pub repo: Arc<MemoryRepo>,
    /// Embedding backend (trait object).
    pub embedding: Box<dyn EmbeddingBackend>,
    /// Scoped search index.
    pub index: ScopedIndex,
    /// Authentication provider.
    pub auth: AuthProvider,
    /// Branch name. NOTE(review): `AppState::new` does not call `validate_branch_name`;
    /// presumably validated by the caller — confirm.
    pub branch: String,
}
593
594impl AppState {
595 pub fn new(
597 repo: Arc<MemoryRepo>,
598 branch: String,
599 embedding: Box<dyn EmbeddingBackend>,
600 index: ScopedIndex,
601 auth: AuthProvider,
602 ) -> Self {
603 Self {
604 repo,
605 embedding,
606 index,
607 auth,
608 branch,
609 }
610 }
611}
612
// Unit tests for the model types and the validation/parsing helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a fully populated, project-scoped memory with a fixed id and fixed
    // timestamps so that round-trip assertions are deterministic.
    fn make_memory() -> Memory {
        let meta = MemoryMetadata {
            tags: vec!["test".to_string(), "round-trip".to_string()],
            scope: Scope::Project("my-project".to_string()),
            created_at: DateTime::from_timestamp(1_700_000_000, 0).unwrap(),
            updated_at: DateTime::from_timestamp(1_700_000_100, 0).unwrap(),
            source: Some("unit-test".to_string()),
        };
        Memory {
            id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
            name: "test-memory".to_string(),
            content: "# Hello\n\nThis is a test memory.".to_string(),
            metadata: meta,
        }
    }

    // ---- Markdown / frontmatter round-trips ----

    #[test]
    fn round_trip_markdown() {
        let original = make_memory();
        let rendered = original.to_markdown().expect("to_markdown should not fail");
        let parsed = Memory::from_markdown(&rendered).expect("from_markdown should not fail");

        assert_eq!(original.id, parsed.id);
        assert_eq!(original.name, parsed.name);
        assert_eq!(original.content, parsed.content);
        assert_eq!(original.metadata.tags, parsed.metadata.tags);
        assert_eq!(original.metadata.scope, parsed.metadata.scope);
        // Timestamps are compared at whole-second precision.
        assert_eq!(
            original.metadata.created_at.timestamp(),
            parsed.metadata.created_at.timestamp()
        );
        assert_eq!(
            original.metadata.updated_at.timestamp(),
            parsed.metadata.updated_at.timestamp()
        );
        assert_eq!(original.metadata.source, parsed.metadata.source);
    }

    #[test]
    fn round_trip_global_scope() {
        let meta = MemoryMetadata::new(Scope::Global, vec!["global-tag".to_string()], None);
        let mem = Memory::new("global-mem".to_string(), "Some content.".to_string(), meta);
        let rendered = mem.to_markdown().unwrap();
        let parsed = Memory::from_markdown(&rendered).unwrap();

        assert_eq!(parsed.metadata.scope, Scope::Global);
        assert_eq!(parsed.metadata.source, None);
        assert_eq!(parsed.content, "Some content.");
    }

    #[test]
    fn round_trip_no_source() {
        let meta = MemoryMetadata::new(Scope::Project("proj".to_string()), vec![], None);
        let mem = Memory::new("no-src".to_string(), "Body.".to_string(), meta);
        let md = mem.to_markdown().unwrap();
        // A `None` source must be omitted from the frontmatter entirely.
        assert!(!md.contains("source:"));
        let parsed = Memory::from_markdown(&md).unwrap();
        assert_eq!(parsed.metadata.source, None);
    }

    #[test]
    fn from_markdown_missing_frontmatter_fails() {
        let result = Memory::from_markdown("just plain text");
        assert!(result.is_err());
    }

    // ---- Scope ----

    #[test]
    fn scope_dir_prefix() {
        assert_eq!(Scope::Global.dir_prefix(), "global");
        assert_eq!(
            Scope::Project("foo".to_string()).dir_prefix(),
            "projects/foo"
        );
    }

    #[test]
    fn scope_from_str_global() {
        assert_eq!("global".parse::<Scope>().unwrap(), Scope::Global);
    }

    #[test]
    fn scope_from_str_project() {
        assert_eq!(
            "project:my-proj".parse::<Scope>().unwrap(),
            Scope::Project("my-proj".to_string())
        );
    }

    #[test]
    fn scope_from_str_empty_project_name_fails() {
        assert!("project:".parse::<Scope>().is_err());
    }

    #[test]
    fn scope_from_str_unknown_fails() {
        assert!("unknown".parse::<Scope>().is_err());
        // Parsing is case-sensitive.
        assert!("PROJECT:foo".parse::<Scope>().is_err());
    }

    #[test]
    fn scope_from_str_project_traversal_fails() {
        assert!("project:../../etc".parse::<Scope>().is_err());
    }

    // ---- validate_name ----

    #[test]
    fn validate_name_accepts_valid() {
        assert!(validate_name("my-memory").is_ok());
        assert!(validate_name("some_memory").is_ok());
        assert!(validate_name("nested/path").is_ok());
        assert!(validate_name("v1.2.3").is_ok());
    }

    #[test]
    fn validate_name_rejects_traversal() {
        assert!(validate_name("../../etc/passwd").is_err());
        assert!(validate_name("..").is_err());
        assert!(validate_name(".hidden").is_err());
        assert!(validate_name("a/../b").is_err());
    }

    #[test]
    fn validate_name_rejects_empty() {
        assert!(validate_name("").is_err());
    }

    #[test]
    fn validate_name_rejects_special_chars() {
        assert!(validate_name("foo;bar").is_err());
        assert!(validate_name("foo bar").is_err());
        assert!(validate_name("foo\0bar").is_err());
    }

    #[test]
    fn validate_name_rejects_empty_component() {
        assert!(validate_name("foo//bar").is_err());
        assert!(validate_name("/absolute").is_err());
    }

    // ---- parse_scope ----

    #[test]
    fn test_parse_scope_none_defaults_global() {
        assert_eq!(parse_scope(None).unwrap(), Scope::Global);
    }

    #[test]
    fn test_parse_scope_some_global() {
        assert_eq!(parse_scope(Some("global")).unwrap(), Scope::Global);
    }

    #[test]
    fn test_parse_scope_some_project() {
        assert_eq!(
            parse_scope(Some("project:my-proj")).unwrap(),
            Scope::Project("my-proj".to_string())
        );
    }

    // ---- parse_qualified_name ----

    #[test]
    fn test_parse_qualified_name_global() {
        let (scope, name) = parse_qualified_name("global/my-memory").unwrap();
        assert_eq!(scope, Scope::Global);
        assert_eq!(name, "my-memory");
    }

    #[test]
    fn test_parse_qualified_name_project() {
        let (scope, name) = parse_qualified_name("projects/my-project/my-memory").unwrap();
        assert_eq!(scope, Scope::Project("my-project".to_string()));
        assert_eq!(name, "my-memory");
    }

    #[test]
    fn test_parse_qualified_name_nested() {
        // Everything after the project component is the (nested) memory name.
        let (scope, name) = parse_qualified_name("projects/my-project/nested/memory").unwrap();
        assert_eq!(scope, Scope::Project("my-project".to_string()));
        assert_eq!(name, "nested/memory");
    }

    // ---- validate_branch_name ----

    #[test]
    fn validate_branch_name_accepts_valid() {
        assert!(validate_branch_name("main").is_ok());
        assert!(validate_branch_name("feature/foo").is_ok());
        assert!(validate_branch_name("release-1.0").is_ok());
        assert!(validate_branch_name("a/b/c").is_ok());
        assert!(validate_branch_name("my-branch_v2").is_ok());
    }

    #[test]
    fn validate_branch_name_rejects_empty() {
        assert!(validate_branch_name("").is_err());
    }

    #[test]
    fn validate_branch_name_rejects_dot_dot() {
        assert!(validate_branch_name("foo..bar").is_err());
        assert!(validate_branch_name("..").is_err());
    }

    #[test]
    fn validate_branch_name_rejects_invalid_chars() {
        for name in &[
            "foo bar", "foo~bar", "foo^bar", "foo:bar", "foo?bar", "foo*bar", "foo[bar", "foo\\bar",
        ] {
            assert!(
                validate_branch_name(name).is_err(),
                "should reject: {}",
                name
            );
        }
    }

    #[test]
    fn validate_branch_name_rejects_invalid_start_end() {
        assert!(validate_branch_name("/foo").is_err());
        assert!(validate_branch_name("foo/").is_err());
        assert!(validate_branch_name(".foo").is_err());
        assert!(validate_branch_name("foo.").is_err());
    }

    #[test]
    fn validate_branch_name_rejects_consecutive_slashes() {
        assert!(validate_branch_name("foo//bar").is_err());
    }

    // ---- parse_scope_filter ----

    #[test]
    fn scope_filter_none_defaults_to_global_only() {
        assert_eq!(parse_scope_filter(None).unwrap(), ScopeFilter::GlobalOnly);
    }

    #[test]
    fn scope_filter_global_returns_global_only() {
        assert_eq!(
            parse_scope_filter(Some("global")).unwrap(),
            ScopeFilter::GlobalOnly
        );
    }

    #[test]
    fn scope_filter_project_returns_project_and_global() {
        assert_eq!(
            parse_scope_filter(Some("project:my-proj")).unwrap(),
            ScopeFilter::ProjectAndGlobal("my-proj".to_string()),
        );
    }

    #[test]
    fn scope_filter_all_returns_all() {
        assert_eq!(parse_scope_filter(Some("all")).unwrap(), ScopeFilter::All);
    }

    #[test]
    fn scope_filter_invalid_returns_error() {
        assert!(parse_scope_filter(Some("bogus")).is_err());
    }
}