use chrono::{DateTime, Utc};
use rmcp::schemars;
use serde::{Deserialize, Serialize};
use std::{fmt, str::FromStr};
use uuid::Uuid;
use crate::error::MemoryError;
/// Checks that `name` is an acceptable `/`-separated memory (or project) name.
///
/// Rules, checked in this order:
/// 1. the name is not empty;
/// 2. it has at most three `/`-separated components;
/// 3. for every component, left to right: it is not empty (so no leading,
///    trailing or doubled `/`), it does not start with `.`, and it contains
///    only characters accepted by [`char::is_alphanumeric`] plus `-`, `_`
///    and `.`.
///
/// # Errors
/// Returns [`MemoryError::InvalidInput`] describing the first rule violated.
pub fn validate_name(name: &str) -> Result<(), MemoryError> {
    const MAX_DEPTH: usize = 3;
    let invalid = |reason: String| MemoryError::InvalidInput { reason };
    let is_allowed = |c: char| c.is_alphanumeric() || matches!(c, '-' | '_' | '.');

    if name.is_empty() {
        return Err(invalid("name must not be empty".to_string()));
    }
    // Depth is judged on the raw split, before any per-component checks.
    if name.split('/').count() > MAX_DEPTH {
        return Err(invalid(format!(
            "name '{name}' exceeds maximum nesting depth of 3"
        )));
    }
    for part in name.split('/') {
        if part.is_empty() {
            return Err(invalid(format!(
                "name '{name}' contains an empty path component"
            )));
        }
        // Rejects hidden entries as well as the `.` / `..` traversal components.
        if part.starts_with('.') {
            return Err(invalid(format!(
                "name '{name}' contains a dot-prefixed component '{part}'"
            )));
        }
        if part.chars().any(|c| !is_allowed(c)) {
            return Err(invalid(format!(
                "name '{name}' contains disallowed characters in component '{part}'"
            )));
        }
    }
    Ok(())
}
/// Checks that `branch` is usable as a git branch name.
///
/// The checks follow the rules of `git check-ref-format`:
/// * not empty and no `..` anywhere;
/// * no ASCII control characters, space, `~`, `^`, `:`, `?`, `*`, `[` or `\`;
/// * does not start or end with `/`, does not start or end with `.`;
/// * no consecutive slashes;
/// * does not start with `-` (forbidden for branch names specifically);
/// * is not the single character `@` and does not contain `@{`;
/// * no `/`-separated component starts with `.` or ends with `.lock`.
///
/// The first six checks run in their original order with their original
/// messages; the remaining ones were added so that names git itself refuses
/// are reported here instead of failing later inside a git operation.
///
/// # Errors
/// Returns [`MemoryError::InvalidInput`] describing the first rule violated.
pub fn validate_branch_name(branch: &str) -> Result<(), MemoryError> {
    const FORBIDDEN_CHARS: [char; 8] = [' ', '~', '^', ':', '?', '*', '[', '\\'];
    let invalid = |reason: String| MemoryError::InvalidInput { reason };

    if branch.is_empty() {
        return Err(invalid("branch name cannot be empty".into()));
    }
    if branch.contains("..") {
        return Err(invalid("branch name cannot contain '..'".into()));
    }
    if let Some(c) = branch
        .chars()
        .find(|c| c.is_ascii_control() || FORBIDDEN_CHARS.contains(c))
    {
        return Err(invalid(format!(
            "branch name contains invalid character '{}'",
            c
        )));
    }
    if branch.starts_with('/')
        || branch.ends_with('/')
        || branch.ends_with('.')
        || branch.starts_with('.')
    {
        return Err(invalid(
            "branch name has invalid start/end character".into(),
        ));
    }
    if branch.contains("//") {
        return Err(invalid(
            "branch name contains consecutive slashes".into(),
        ));
    }

    // A leading dash would be parsed as a command-line option by git tooling
    // and is explicitly forbidden for branch names.
    if branch.starts_with('-') {
        return Err(invalid("branch name cannot start with '-'".into()));
    }
    // `@` alone and `@{` are reserved by git's revision syntax.
    if branch == "@" {
        return Err(invalid(
            "branch name cannot be the single character '@'".into(),
        ));
    }
    if branch.contains("@{") {
        return Err(invalid("branch name cannot contain '@{'".into()));
    }
    // Per-component rules: the whole-name check above only covers the first
    // component's leading dot.
    for component in branch.split('/') {
        if component.starts_with('.') {
            return Err(invalid(format!(
                "branch name component '{}' cannot start with '.'",
                component
            )));
        }
        if component.ends_with(".lock") {
            return Err(invalid(format!(
                "branch name component '{}' cannot end with '.lock'",
                component
            )));
        }
    }
    Ok(())
}
/// Namespace a memory belongs to.
///
/// Serialized by serde in adjacently tagged form: the variant name is stored
/// under the `type` key and, for [`Scope::Project`], the project name under
/// the `name` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "type", content = "name")]
#[non_exhaustive]
pub enum Scope {
/// Not tied to any particular project.
Global,
/// Tied to the project with the given name.
Project(String),
}
impl Scope {
pub fn dir_prefix(&self) -> String {
match self {
Scope::Global => "global".to_string(),
Scope::Project(name) => format!("projects/{}", name),
}
}
}
impl fmt::Display for Scope {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Scope::Global => write!(f, "global"),
Scope::Project(name) => write!(f, "project:{}", name),
}
}
}
/// Parses the textual scope forms `global` and `project:<name>`.
///
/// Matching is case-sensitive. The project name must be non-empty, must not
/// contain `/`, and must pass [`validate_name`].
impl FromStr for Scope {
    type Err = MemoryError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let invalid = |reason: String| MemoryError::InvalidInput { reason };

        if s == "global" {
            return Ok(Scope::Global);
        }
        match s.strip_prefix("project:") {
            None => Err(invalid(format!(
                "unrecognised scope '{}'; expected 'global' or 'project:<name>'",
                s
            ))),
            Some("") => Err(invalid(
                "project scope requires a non-empty name after 'project:'".to_string(),
            )),
            // Checked before `validate_name`, which on its own would allow
            // nested names.
            Some(project) if project.contains('/') => Err(invalid(
                "project name must not contain '/'".to_string(),
            )),
            Some(project) => {
                validate_name(project)?;
                Ok(Scope::Project(project.to_string()))
            }
        }
    }
}
/// Descriptive data stored alongside a memory's content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMetadata {
/// Tags attached to the memory.
pub tags: Vec<String>,
/// Scope the memory belongs to.
pub scope: Scope,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Last-update timestamp (UTC); `MemoryMetadata::new` sets it equal to `created_at`.
pub updated_at: DateTime<Utc>,
/// Optional origin of the memory — exact semantics are not visible here; TODO confirm with callers.
pub source: Option<String>,
}
impl MemoryMetadata {
pub fn new(scope: Scope, tags: Vec<String>, source: Option<String>) -> Self {
let now = Utc::now();
Self {
tags,
scope,
created_at: now,
updated_at: now,
source,
}
}
}
/// A single stored memory: identifier, name, body text and metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Memory {
/// Identifier; `Memory::new` fills it with a random (v4) UUID rendered as a string.
pub id: String,
/// Name of the memory.
pub name: String,
/// Body text; written after the frontmatter by `Memory::to_markdown`.
pub content: String,
/// Tags, scope, timestamps and optional source.
pub metadata: MemoryMetadata,
}
impl Memory {
    /// Creates a memory whose `id` is a freshly generated v4 UUID string.
    pub fn new(name: String, content: String, metadata: MemoryMetadata) -> Self {
        let id = Uuid::new_v4().to_string();
        Memory {
            id,
            name,
            content,
            metadata,
        }
    }

    /// Renders the memory as a markdown document with YAML frontmatter.
    ///
    /// Layout: `---`, the YAML mapping (`id`, `name`, `tags`, `scope`,
    /// `created_at`, `updated_at` and, only when present, `source`), `---`,
    /// one blank line, then the content verbatim.
    ///
    /// # Errors
    /// Propagates the error from YAML serialization, converted into
    /// [`MemoryError`].
    pub fn to_markdown(&self) -> Result<String, MemoryError> {
        // Field names and order define the emitted YAML keys; keep in sync
        // with the deserializing twin in `from_markdown`.
        #[derive(Serialize)]
        struct Frontmatter<'a> {
            id: &'a str,
            name: &'a str,
            tags: &'a [String],
            scope: &'a Scope,
            created_at: &'a DateTime<Utc>,
            updated_at: &'a DateTime<Utc>,
            #[serde(skip_serializing_if = "Option::is_none")]
            source: Option<&'a str>,
        }

        let meta = &self.metadata;
        let yaml = serde_yaml_ng::to_string(&Frontmatter {
            id: self.id.as_str(),
            name: self.name.as_str(),
            tags: meta.tags.as_slice(),
            scope: &meta.scope,
            created_at: &meta.created_at,
            updated_at: &meta.updated_at,
            source: meta.source.as_deref(),
        })?;

        let mut document = String::from("---\n");
        document.push_str(&yaml);
        document.push_str("---\n\n");
        document.push_str(&self.content);
        Ok(document)
    }

    /// Parses a document in the layout produced by [`Memory::to_markdown`].
    ///
    /// The input must start with `---\n`; the frontmatter ends at the first
    /// `\n---\n` after that. Everything following the closing delimiter is
    /// the content, with all leading `\n` characters removed.
    ///
    /// # Errors
    /// Returns [`MemoryError::InvalidInput`] when either delimiter is
    /// missing, and propagates YAML deserialization errors.
    pub fn from_markdown(raw: &str) -> Result<Self, MemoryError> {
        #[derive(Deserialize)]
        struct Frontmatter {
            id: String,
            name: String,
            tags: Vec<String>,
            scope: Scope,
            created_at: DateTime<Utc>,
            updated_at: DateTime<Utc>,
            source: Option<String>,
        }

        let rest = match raw.strip_prefix("---\n") {
            Some(rest) => rest,
            None => {
                return Err(MemoryError::InvalidInput {
                    reason: "missing opening frontmatter delimiter".to_string(),
                })
            }
        };
        let (yaml_str, after) = match rest.split_once("\n---\n") {
            Some(parts) => parts,
            None => {
                return Err(MemoryError::InvalidInput {
                    reason: "missing closing frontmatter delimiter".to_string(),
                })
            }
        };

        let Frontmatter {
            id,
            name,
            tags,
            scope,
            created_at,
            updated_at,
            source,
        } = serde_yaml_ng::from_str::<Frontmatter>(yaml_str)?;

        Ok(Memory {
            id,
            name,
            // Drops the blank separator line (and any further leading newlines).
            content: after.trim_start_matches('\n').to_string(),
            metadata: MemoryMetadata {
                tags,
                scope,
                created_at,
                updated_at,
                source,
            },
        })
    }
}
/// Scope selection produced by `parse_scope_filter`.
///
/// How each variant is applied to a query is decided by code not visible here;
/// the descriptions below follow the variant names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScopeFilter {
/// Only the global scope.
GlobalOnly,
/// The named project together with the global scope.
ProjectAndGlobal(String),
/// Every scope.
All,
}
/// Converts an optional scope string into a [`ScopeFilter`].
///
/// * `None` or `"global"` → [`ScopeFilter::GlobalOnly`]
/// * `"all"` → [`ScopeFilter::All`]
/// * anything else is parsed as a [`Scope`]; a project scope becomes
///   [`ScopeFilter::ProjectAndGlobal`] carrying the project name.
///
/// # Errors
/// Returns the [`Scope`] parse error for any other string.
pub fn parse_scope_filter(scope: Option<&str>) -> Result<ScopeFilter, MemoryError> {
    let raw = match scope {
        Some(raw) => raw,
        None => return Ok(ScopeFilter::GlobalOnly),
    };
    if raw == "global" {
        return Ok(ScopeFilter::GlobalOnly);
    }
    if raw == "all" {
        return Ok(ScopeFilter::All);
    }
    raw.parse::<Scope>().map(|parsed| match parsed {
        Scope::Global => ScopeFilter::GlobalOnly,
        Scope::Project(name) => ScopeFilter::ProjectAndGlobal(name),
    })
}
/// Converts an optional scope string into a [`Scope`], defaulting to
/// [`Scope::Global`] when no string is given.
///
/// # Errors
/// Returns the [`Scope`] parse error when the string is not a valid scope.
pub fn parse_scope(scope: Option<&str>) -> Result<Scope, MemoryError> {
    scope.map_or(Ok(Scope::Global), |raw| raw.parse::<Scope>())
}
/// Splits a qualified name into its scope and the memory name within it.
///
/// Accepted forms:
/// * `global/<name>` → ([`Scope::Global`], `<name>`)
/// * `projects/<project>/<name>` → ([`Scope::Project`], `<name>`), where
///   `<project>` is the text up to the first `/` after the prefix and
///   `<name>` is everything after it (so it may itself contain `/`).
///
/// Both the project and the memory name must pass [`validate_name`].
///
/// # Errors
/// Returns [`MemoryError::InvalidInput`] when the prefix is unknown, a part is
/// missing or empty, or a part fails validation.
pub fn parse_qualified_name(qualified: &str) -> Result<(Scope, String), MemoryError> {
    if let Some(memory) = qualified.strip_prefix("global/") {
        validate_name(memory)?;
        return Ok((Scope::Global, memory.to_string()));
    }

    let rest = match qualified.strip_prefix("projects/") {
        Some(rest) => rest,
        None => {
            return Err(MemoryError::InvalidInput {
                reason: format!(
                    "malformed qualified name '{}': must start with 'global/' or 'projects/'",
                    qualified
                ),
            })
        }
    };

    match rest.split_once('/') {
        None => Err(MemoryError::InvalidInput {
            reason: format!(
                "malformed qualified name '{}': missing memory name after project",
                qualified
            ),
        }),
        Some((project, memory)) if project.is_empty() || memory.is_empty() => {
            Err(MemoryError::InvalidInput {
                reason: format!(
                    "malformed qualified name '{}': project or memory name is empty",
                    qualified
                ),
            })
        }
        Some((project, memory)) => {
            validate_name(project)?;
            validate_name(memory)?;
            Ok((Scope::Project(project.to_string()), memory.to_string()))
        }
    }
}
// Deserializable argument set, presumably for a "remember" (store) operation;
// the consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RememberArgs {
// Body text to store (required).
pub content: String,
// Memory name (required).
pub name: String,
// Tags; an absent field deserializes to an empty list.
#[serde(default)]
pub tags: Vec<String>,
// Optional scope string; absent deserializes to `None`.
#[serde(default)]
pub scope: Option<String>,
// Optional source string; absent deserializes to `None`.
#[serde(default)]
pub source: Option<String>,
}
// Deserializable argument set, presumably for a "recall" (search) operation;
// the consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct RecallArgs {
// Query text (required).
pub query: String,
// Optional scope string; absent deserializes to `None`.
#[serde(default)]
pub scope: Option<String>,
// Optional result limit; absent deserializes to `None` (the default applied
// by the consumer is not visible here).
#[serde(default)]
pub limit: Option<usize>,
}
// Deserializable argument set, presumably for a "forget" (delete) operation;
// the consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ForgetArgs {
// Memory name (required).
pub name: String,
// Optional scope string; absent deserializes to `None`.
#[serde(default)]
pub scope: Option<String>,
}
// Deserializable argument set, presumably for an "edit" operation; the
// consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct EditArgs {
// Memory name (required).
pub name: String,
// Optional new content; absent deserializes to `None`
// (presumably "leave unchanged" — confirm against the handler).
#[serde(default)]
pub content: Option<String>,
// Optional new tag list; absent deserializes to `None`
// (presumably "leave unchanged" — confirm against the handler).
#[serde(default)]
pub tags: Option<Vec<String>>,
// Optional scope string; absent deserializes to `None`.
#[serde(default)]
pub scope: Option<String>,
}
// Deserializable argument set, presumably for a "list" operation; the
// consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ListArgs {
// Optional scope string; absent deserializes to `None`.
#[serde(default)]
pub scope: Option<String>,
}
// Deserializable argument set, presumably for a "read" operation; the
// consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct ReadArgs {
// Memory name (required).
pub name: String,
// Optional scope string; absent deserializes to `None`.
#[serde(default)]
pub scope: Option<String>,
}
// Deserializable argument set, presumably for a "sync" operation; the
// consuming handler is not visible here.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the derived JSON schema, which would change its output.
#[derive(Debug, Deserialize, schemars::JsonSchema)]
pub struct SyncArgs {
// Optional flag; absent deserializes to `None`. Judging by the name it asks
// for a pull before the rest of the sync — confirm against the handler.
#[serde(default)]
pub pull_first: Option<bool>,
}
/// Outcome of a pull, as reported by code not visible in this file.
///
/// The `old_head` / `new_head` fields are 20-byte identifiers — presumably raw
/// SHA-1 commit ids before and after the pull; TODO confirm with the producer.
#[derive(Debug)]
#[non_exhaustive]
pub enum PullResult {
/// No remote was available to pull from (inferred from the variant name).
NoRemote,
/// Nothing new to pull (inferred from the variant name).
UpToDate,
/// The head moved from `old_head` to `new_head` by fast-forward
/// (inferred from the variant name).
FastForward {
old_head: [u8; 20],
new_head: [u8; 20],
},
/// A merge took the head from `old_head` to `new_head`.
Merged {
/// Count of conflicts resolved during the merge (inferred from the field name).
conflicts_resolved: usize,
old_head: [u8; 20],
new_head: [u8; 20],
},
}
/// Two lists of changed memories; the default value has both lists empty.
///
/// What the strings identify (names, qualified names or paths) is not visible
/// here — confirm with the code that fills them.
#[derive(Debug, Default)]
pub struct ChangedMemories {
/// Entries that were added or updated (inferred from the field name).
pub upserted: Vec<String>,
/// Entries that were removed.
pub removed: Vec<String>,
}
impl ChangedMemories {
    /// Returns `true` when both the `upserted` and the `removed` list are empty.
    pub fn is_empty(&self) -> bool {
        let Self { upserted, removed } = self;
        [upserted, removed].iter().all(|names| names.is_empty())
    }
}
/// Counters describing a reindex run; the default value has every counter at zero.
///
/// The code that increments them is not visible here, so the per-field
/// descriptions follow the field names.
#[derive(Debug, Default)]
pub struct ReindexStats {
/// Number of entries added.
pub added: usize,
/// Number of entries updated.
pub updated: usize,
/// Number of entries removed.
pub removed: usize,
/// Number of errors encountered.
pub errors: usize,
}
use std::sync::Arc;
use crate::{
auth::AuthProvider, embedding::EmbeddingBackend, index::ScopedIndex, repo::MemoryRepo,
};
/// Bundle of the application's long-lived components.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with
/// a struct literal and has to go through `AppState::new`.
#[non_exhaustive]
pub struct AppState {
/// Reference-counted handle to the memory repository.
pub repo: Arc<MemoryRepo>,
/// Embedding backend, held as a boxed trait object.
pub embedding: Box<dyn EmbeddingBackend>,
/// Scoped index.
pub index: ScopedIndex,
/// Authentication provider.
pub auth: AuthProvider,
/// Branch name (presumably the git branch the repository works on — confirm).
pub branch: String,
}
impl AppState {
pub fn new(
repo: Arc<MemoryRepo>,
branch: String,
embedding: Box<dyn EmbeddingBackend>,
index: ScopedIndex,
auth: AuthProvider,
) -> Self {
Self {
repo,
embedding,
index,
auth,
branch,
}
}
}
// Unit tests for the markdown round trip, scope parsing, name/branch
// validation, qualified-name parsing and scope-filter parsing defined above.
#[cfg(test)]
mod tests {
use super::*;
// Fixture: a project-scoped memory with fixed id, timestamps, tags and source,
// so that round-trip assertions are deterministic.
fn make_memory() -> Memory {
let meta = MemoryMetadata {
tags: vec!["test".to_string(), "round-trip".to_string()],
scope: Scope::Project("my-project".to_string()),
created_at: DateTime::from_timestamp(1_700_000_000, 0).unwrap(),
updated_at: DateTime::from_timestamp(1_700_000_100, 0).unwrap(),
source: Some("unit-test".to_string()),
};
Memory {
id: "550e8400-e29b-41d4-a716-446655440000".to_string(),
name: "test-memory".to_string(),
content: "# Hello\n\nThis is a test memory.".to_string(),
metadata: meta,
}
}
// Rendering then parsing must preserve every field (timestamps compared at
// whole-second resolution).
#[test]
fn round_trip_markdown() {
let original = make_memory();
let rendered = original.to_markdown().expect("to_markdown should not fail");
let parsed = Memory::from_markdown(&rendered).expect("from_markdown should not fail");
assert_eq!(original.id, parsed.id);
assert_eq!(original.name, parsed.name);
assert_eq!(original.content, parsed.content);
assert_eq!(original.metadata.tags, parsed.metadata.tags);
assert_eq!(original.metadata.scope, parsed.metadata.scope);
assert_eq!(
original.metadata.created_at.timestamp(),
parsed.metadata.created_at.timestamp()
);
assert_eq!(
original.metadata.updated_at.timestamp(),
parsed.metadata.updated_at.timestamp()
);
assert_eq!(original.metadata.source, parsed.metadata.source);
}
// Global scope and an absent source survive the round trip.
#[test]
fn round_trip_global_scope() {
let meta = MemoryMetadata::new(Scope::Global, vec!["global-tag".to_string()], None);
let mem = Memory::new("global-mem".to_string(), "Some content.".to_string(), meta);
let rendered = mem.to_markdown().unwrap();
let parsed = Memory::from_markdown(&rendered).unwrap();
assert_eq!(parsed.metadata.scope, Scope::Global);
assert_eq!(parsed.metadata.source, None);
assert_eq!(parsed.content, "Some content.");
}
// A `None` source is omitted from the frontmatter and parses back as `None`.
#[test]
fn round_trip_no_source() {
let meta = MemoryMetadata::new(Scope::Project("proj".to_string()), vec![], None);
let mem = Memory::new("no-src".to_string(), "Body.".to_string(), meta);
let md = mem.to_markdown().unwrap();
assert!(!md.contains("source:"));
let parsed = Memory::from_markdown(&md).unwrap();
assert_eq!(parsed.metadata.source, None);
}
// Input without the opening `---` delimiter is rejected.
#[test]
fn from_markdown_missing_frontmatter_fails() {
let result = Memory::from_markdown("just plain text");
assert!(result.is_err());
}
// `dir_prefix` yields "global" and "projects/<name>".
#[test]
fn scope_dir_prefix() {
assert_eq!(Scope::Global.dir_prefix(), "global");
assert_eq!(
Scope::Project("foo".to_string()).dir_prefix(),
"projects/foo"
);
}
// "global" parses to `Scope::Global`.
#[test]
fn scope_from_str_global() {
assert_eq!("global".parse::<Scope>().unwrap(), Scope::Global);
}
// "project:<name>" parses to `Scope::Project`.
#[test]
fn scope_from_str_project() {
assert_eq!(
"project:my-proj".parse::<Scope>().unwrap(),
Scope::Project("my-proj".to_string())
);
}
// An empty project name after the prefix is an error.
#[test]
fn scope_from_str_empty_project_name_fails() {
assert!("project:".parse::<Scope>().is_err());
}
// Unknown words and a differently-cased prefix are errors.
#[test]
fn scope_from_str_unknown_fails() {
assert!("unknown".parse::<Scope>().is_err());
assert!("PROJECT:foo".parse::<Scope>().is_err());
}
// Path traversal in the project name is rejected.
#[test]
fn scope_from_str_project_traversal_fails() {
assert!("project:../../etc".parse::<Scope>().is_err());
}
// Plain, nested and dotted names are accepted.
#[test]
fn validate_name_accepts_valid() {
assert!(validate_name("my-memory").is_ok());
assert!(validate_name("some_memory").is_ok());
assert!(validate_name("nested/path").is_ok());
assert!(validate_name("v1.2.3").is_ok());
}
// Traversal and dot-prefixed components are rejected.
#[test]
fn validate_name_rejects_traversal() {
assert!(validate_name("../../etc/passwd").is_err());
assert!(validate_name("..").is_err());
assert!(validate_name(".hidden").is_err());
assert!(validate_name("a/../b").is_err());
}
// The empty name is rejected.
#[test]
fn validate_name_rejects_empty() {
assert!(validate_name("").is_err());
}
// Characters outside the allowed set (`;`, space, NUL) are rejected.
#[test]
fn validate_name_rejects_special_chars() {
assert!(validate_name("foo;bar").is_err());
assert!(validate_name("foo bar").is_err());
assert!(validate_name("foo\0bar").is_err());
}
// Doubled and leading slashes produce empty components and are rejected.
#[test]
fn validate_name_rejects_empty_component() {
assert!(validate_name("foo//bar").is_err());
assert!(validate_name("/absolute").is_err());
}
// `parse_scope(None)` defaults to the global scope.
#[test]
fn test_parse_scope_none_defaults_global() {
assert_eq!(parse_scope(None).unwrap(), Scope::Global);
}
// `parse_scope` with an explicit "global".
#[test]
fn test_parse_scope_some_global() {
assert_eq!(parse_scope(Some("global")).unwrap(), Scope::Global);
}
// `parse_scope` with a project scope.
#[test]
fn test_parse_scope_some_project() {
assert_eq!(
parse_scope(Some("project:my-proj")).unwrap(),
Scope::Project("my-proj".to_string())
);
}
// "global/<name>" splits into the global scope and the name.
#[test]
fn test_parse_qualified_name_global() {
let (scope, name) = parse_qualified_name("global/my-memory").unwrap();
assert_eq!(scope, Scope::Global);
assert_eq!(name, "my-memory");
}
// "projects/<project>/<name>" splits into project scope and name.
#[test]
fn test_parse_qualified_name_project() {
let (scope, name) = parse_qualified_name("projects/my-project/my-memory").unwrap();
assert_eq!(scope, Scope::Project("my-project".to_string()));
assert_eq!(name, "my-memory");
}
// Everything after the project component, including further slashes, is the name.
#[test]
fn test_parse_qualified_name_nested() {
let (scope, name) = parse_qualified_name("projects/my-project/nested/memory").unwrap();
assert_eq!(scope, Scope::Project("my-project".to_string()));
assert_eq!(name, "nested/memory");
}
// Typical branch names are accepted.
#[test]
fn validate_branch_name_accepts_valid() {
assert!(validate_branch_name("main").is_ok());
assert!(validate_branch_name("feature/foo").is_ok());
assert!(validate_branch_name("release-1.0").is_ok());
assert!(validate_branch_name("a/b/c").is_ok());
assert!(validate_branch_name("my-branch_v2").is_ok());
}
// The empty branch name is rejected.
#[test]
fn validate_branch_name_rejects_empty() {
assert!(validate_branch_name("").is_err());
}
// `..` anywhere in the branch name is rejected.
#[test]
fn validate_branch_name_rejects_dot_dot() {
assert!(validate_branch_name("foo..bar").is_err());
assert!(validate_branch_name("..").is_err());
}
// Each character on the forbidden list is rejected.
#[test]
fn validate_branch_name_rejects_invalid_chars() {
for name in &[
"foo bar", "foo~bar", "foo^bar", "foo:bar", "foo?bar", "foo*bar", "foo[bar", "foo\\bar",
] {
assert!(
validate_branch_name(name).is_err(),
"should reject: {}",
name
);
}
}
// Leading/trailing `/` and `.` are rejected.
#[test]
fn validate_branch_name_rejects_invalid_start_end() {
assert!(validate_branch_name("/foo").is_err());
assert!(validate_branch_name("foo/").is_err());
assert!(validate_branch_name(".foo").is_err());
assert!(validate_branch_name("foo.").is_err());
}
// `//` inside the branch name is rejected.
#[test]
fn validate_branch_name_rejects_consecutive_slashes() {
assert!(validate_branch_name("foo//bar").is_err());
}
// A missing scope filter defaults to global-only.
#[test]
fn scope_filter_none_defaults_to_global_only() {
assert_eq!(parse_scope_filter(None).unwrap(), ScopeFilter::GlobalOnly);
}
// An explicit "global" filter is global-only.
#[test]
fn scope_filter_global_returns_global_only() {
assert_eq!(
parse_scope_filter(Some("global")).unwrap(),
ScopeFilter::GlobalOnly
);
}
// A project filter selects the project together with global.
#[test]
fn scope_filter_project_returns_project_and_global() {
assert_eq!(
parse_scope_filter(Some("project:my-proj")).unwrap(),
ScopeFilter::ProjectAndGlobal("my-proj".to_string()),
);
}
// "all" selects every scope.
#[test]
fn scope_filter_all_returns_all() {
assert_eq!(parse_scope_filter(Some("all")).unwrap(), ScopeFilter::All);
}
// Anything that is neither a keyword nor a valid scope is an error.
#[test]
fn scope_filter_invalid_returns_error() {
assert!(parse_scope_filter(Some("bogus")).is_err());
}
}