use serde::{Deserialize, Serialize};
/// Category of a stored memory.
///
/// Serde encodes each variant as its snake_case name (`"user"`, `"project"`,
/// `"reference"`), except `BehaviorPreference`, whose wire name is
/// `"feedback"`. [`MemoryKind::label`] returns the same strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryKind {
    /// User-profile memory; inferred from `user_role` / `expertise_level`.
    User,
    /// Behavioral preference / feedback; inferred from `preference_rule` /
    /// `approved_pattern`, and the inference fallback when no signal is set.
    // Wire name is "feedback" rather than the snake_case variant name.
    #[serde(rename = "feedback")]
    BehaviorPreference,
    /// Project-state memory; inferred from `project_phase` / `relative_date`.
    Project,
    /// Pointer to external material; inferred from `external_url` / `ticket_ref`.
    Reference,
}
impl MemoryKind {
    /// Returns the stable lowercase label for this kind.
    ///
    /// These are the same strings serde uses for the variants, including
    /// `"feedback"` for [`MemoryKind::BehaviorPreference`].
    pub fn label(self) -> &'static str {
        match self {
            MemoryKind::User => "user",
            MemoryKind::BehaviorPreference => "feedback",
            MemoryKind::Project => "project",
            MemoryKind::Reference => "reference",
        }
    }

    /// Guesses a kind from which optional "signal" fields of `metadata` are set.
    ///
    /// Precedence (first match wins):
    /// 1. `user_role` or `expertise_level` -> [`MemoryKind::User`]
    /// 2. `preference_rule` or `approved_pattern` -> [`MemoryKind::BehaviorPreference`]
    /// 3. `project_phase` or `relative_date` -> [`MemoryKind::Project`]
    /// 4. `external_url` or `ticket_ref` -> [`MemoryKind::Reference`]
    ///
    /// With no signal field set, falls back to [`MemoryKind::BehaviorPreference`].
    /// The explicit `metadata.kind` field is not consulted.
    pub fn infer_from_metadata(metadata: &MemoryMetadata) -> Self {
        let describes_user = metadata.user_role.is_some() || metadata.expertise_level.is_some();
        let describes_preference =
            metadata.preference_rule.is_some() || metadata.approved_pattern.is_some();
        let describes_project =
            metadata.project_phase.is_some() || metadata.relative_date.is_some();
        let describes_reference =
            metadata.external_url.is_some() || metadata.ticket_ref.is_some();

        if describes_user {
            MemoryKind::User
        } else if describes_preference {
            MemoryKind::BehaviorPreference
        } else if describes_project {
            MemoryKind::Project
        } else if describes_reference {
            MemoryKind::Reference
        } else {
            MemoryKind::BehaviorPreference
        }
    }
}
/// Metadata attached to a stored memory.
///
/// `name` and `description` carry no `#[serde(default)]`, so they must be
/// present when deserializing. Every other field falls back to its default
/// (`None` / `0`) when absent. The `Option` fields are omitted from serialized
/// output when `None`.
///
/// The optional signal fields (`user_role` through `ticket_ref`) are what
/// [`MemoryKind::infer_from_metadata`] inspects to pick a kind.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryMetadata {
    /// Memory name. Must be non-empty and length-limited under the default
    /// `MemoryValidation` profile.
    pub name: String,
    /// Memory description. Must be non-empty under the default
    /// `MemoryValidation` profile.
    pub description: String,
    /// Explicitly assigned kind, if any.
    // NOTE(review): `MemoryKind::infer_from_metadata` does not read this field.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub kind: Option<MemoryKind>,
    /// Creation time; `0` when absent.
    // NOTE(review): unit/epoch not visible here (presumably Unix seconds) — confirm.
    #[serde(default)]
    pub created_at: u64,
    /// Last-update time; `0` when absent.
    // NOTE(review): presumably the same unit as `created_at` — confirm.
    #[serde(default)]
    pub updated_at: u64,
    /// Associated session id, if recorded.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// User-profile signal; when set, inference yields `MemoryKind::User`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_role: Option<String>,
    /// User-profile signal; when set, inference yields `MemoryKind::User`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expertise_level: Option<String>,
    /// Preference signal; inference yields `BehaviorPreference` unless a
    /// user-profile signal is also set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub preference_rule: Option<String>,
    /// Preference signal; same precedence as `preference_rule`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub approved_pattern: Option<String>,
    /// Project signal; inference yields `Project` unless a user-profile or
    /// preference signal is also set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_phase: Option<String>,
    /// Project signal; same precedence as `project_phase`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relative_date: Option<String>,
    /// Reference signal; inference yields `Reference` only when no other
    /// signal is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub external_url: Option<String>,
    /// Reference signal; same precedence as `external_url`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ticket_ref: Option<String>,
}
/// A request to write one memory: its metadata plus the memory body.
///
/// Checked by [`MemoryValidation::validate`] and [`validate_memory_write`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryWriteRequest {
    /// Metadata describing the memory.
    pub metadata: MemoryMetadata,
    /// Memory body. Validation size-limits it and scans it for forbidden
    /// substrings.
    pub content: String,
}
/// Input for selecting relevant memories.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryQuery {
    /// Current context text; required when deserializing.
    pub current_context: String,
    /// Active tools; empty when absent and omitted from output when empty.
    // NOTE(review): presumably tool names; how they influence selection is not
    // visible in this file — confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub active_tools: Vec<String>,
    /// Memories already shown; empty when absent and omitted from output when empty.
    // NOTE(review): presumably memory ids to exclude from results — confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub already_surfaced: Vec<String>,
    /// Requested number of results; `default_top_k()` (5) when absent.
    #[serde(default = "default_top_k")]
    pub top_k: usize,
}
/// Default result count for a [`MemoryQuery`].
///
/// Used both as the serde default for `MemoryQuery::top_k` and by
/// `MemoryQuery::default()`, so the two cannot drift apart.
fn default_top_k() -> usize {
    5
}

impl Default for MemoryQuery {
    /// An empty query: no context, no active tools, nothing already surfaced,
    /// and `top_k` equal to [`default_top_k`].
    fn default() -> Self {
        Self {
            current_context: String::new(),
            active_tools: Vec::new(),
            already_surfaced: Vec::new(),
            // Same source as `#[serde(default = "default_top_k")]`; previously a
            // duplicated literal.
            top_k: default_top_k(),
        }
    }
}
/// Result of a memory selection: which memories were chosen and why.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryRetrieval {
    /// Identifiers of the selected memories.
    // NOTE(review): id format (e.g. memory name vs. file name) is not visible here — confirm.
    pub selected_memory_ids: Vec<String>,
    /// Free-text explanation of the selection.
    pub selection_rationale: String,
}
/// Reasons a memory write can be rejected.
///
/// Serialized as an internally tagged object. The variant's snake_case name
/// goes in an `"error_kind"` field next to the variant's own fields, e.g.
/// `{"error_kind": "name_too_long", "length": 120, "limit": 100}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "error_kind", rename_all = "snake_case")]
pub enum MemoryValidationError {
    /// A required metadata field (named by `field`) is empty.
    MissingRequiredField { field: String },
    /// Content is `size` bytes, exceeding `limit` bytes.
    ContentTooLarge { size: u32, limit: u32 },
    /// Content contains the forbidden substring `pattern`; `reason` explains why.
    ForbiddenPattern { pattern: String, reason: String },
    /// Presumably an unrecognized memory kind string.
    // NOTE(review): not produced by `MemoryValidation::validate`; presumably raised elsewhere.
    InvalidKind { kind: String },
    /// The name is `length` bytes long, exceeding `limit`.
    NameTooLong { length: usize, limit: usize },
}
/// Limits and content rules applied by [`MemoryValidation::validate`].
#[derive(Debug, Clone)]
pub struct MemoryValidation {
    /// Maximum content size in bytes (`str::len`).
    pub max_size_bytes: u32,
    /// Maximum name length, measured in bytes (`str::len`), not characters.
    pub max_name_length: usize,
    /// Metadata fields that must be non-empty. `validate` only recognizes
    /// `"name"` and `"description"`; other entries have no effect.
    pub required_fields: Vec<String>,
    /// `(substring, reason)` pairs. Content containing any substring is
    /// rejected with the paired reason. Matching is exact and case-sensitive.
    pub forbidden_patterns: Vec<(String, &'static str)>,
}
impl MemoryValidation {
    /// Validates a write request against this profile and returns the first
    /// violation found.
    ///
    /// Checks run in this order:
    /// 1. Each entry of `required_fields`, in list order: `"name"` and
    ///    `"description"` must be non-empty. Any other entry is not recognized
    ///    and is ignored.
    /// 2. `metadata.name` must be at most `max_name_length` bytes (`str::len`).
    /// 3. `content` must be at most `max_size_bytes` bytes.
    /// 4. `content` must not contain any `forbidden_patterns` substring.
    ///    Patterns are tried in list order, and the first one found is reported.
    ///
    /// # Errors
    ///
    /// - [`MemoryValidationError::MissingRequiredField`] for step 1.
    /// - [`MemoryValidationError::NameTooLong`] for step 2.
    /// - [`MemoryValidationError::ContentTooLarge`] for step 3. `size` saturates
    ///   at `u32::MAX` for content larger than 4 GiB.
    /// - [`MemoryValidationError::ForbiddenPattern`] for step 4.
    pub fn validate(&self, request: &MemoryWriteRequest) -> Result<(), MemoryValidationError> {
        let metadata = &request.metadata;

        for field in &self.required_fields {
            let missing = match field.as_str() {
                "name" => metadata.name.is_empty(),
                "description" => metadata.description.is_empty(),
                // No check exists for other field names, so they never fail.
                _ => false,
            };
            if missing {
                return Err(MemoryValidationError::MissingRequiredField {
                    field: field.clone(),
                });
            }
        }

        let name_len = metadata.name.len();
        if name_len > self.max_name_length {
            return Err(MemoryValidationError::NameTooLong {
                length: name_len,
                limit: self.max_name_length,
            });
        }

        let content_len = request.content.len();
        if content_len > self.max_size_bytes as usize {
            return Err(MemoryValidationError::ContentTooLarge {
                // `size` is a u32. A bare `as u32` would wrap for content over
                // 4 GiB and report a misleadingly small size, so saturate instead.
                size: content_len.min(u32::MAX as usize) as u32,
                limit: self.max_size_bytes,
            });
        }

        if let Some((pattern, reason)) = self
            .forbidden_patterns
            .iter()
            .find(|entry| request.content.contains(entry.0.as_str()))
        {
            return Err(MemoryValidationError::ForbiddenPattern {
                pattern: pattern.clone(),
                reason: reason.to_string(),
            });
        }

        Ok(())
    }
}
/// Validates `request` against the built-in default profile,
/// `MemoryValidation::default()`.
///
/// # Errors
///
/// Returns the first [`MemoryValidationError`] reported by
/// [`MemoryValidation::validate`] under that profile.
pub fn validate_memory_write(request: &MemoryWriteRequest) -> Result<(), MemoryValidationError> {
    let default_profile = MemoryValidation::default();
    default_profile.validate(request)
}
/// Memory-subsystem settings.
#[derive(Debug, Clone)]
pub struct MemoryPolicy {
    /// Location of the memory store (empty by default).
    // NOTE(review): not read in this file; presumably a directory path — confirm.
    pub memory_path: String,
    /// Staleness threshold in days (default 2).
    // NOTE(review): not read in this file; presumably used to warn about old memories.
    pub stale_warning_days: u32,
    /// Upper bound applied by `clamp_top_k` (default 5).
    pub retrieval_top_k: usize,
    /// Whether writes should be validated (default `true`).
    // NOTE(review): `validation()` does not consult this flag; presumably callers check it.
    pub validation_enabled: bool,
    /// When `Some`, overrides `MemoryValidation::max_size_bytes` in `validation()`.
    pub max_content_bytes: Option<u32>,
    /// When `Some`, overrides `MemoryValidation::max_name_length` in `validation()`.
    pub max_name_length: Option<usize>,
}
impl Default for MemoryPolicy {
fn default() -> Self {
Self {
memory_path: String::new(),
stale_warning_days: 2,
retrieval_top_k: 5,
validation_enabled: true,
max_content_bytes: None,
max_name_length: None,
}
}
}
impl MemoryPolicy {
pub fn validation(&self) -> MemoryValidation {
let mut v = MemoryValidation::default();
if let Some(bytes) = self.max_content_bytes {
v.max_size_bytes = bytes;
}
if let Some(len) = self.max_name_length {
v.max_name_length = len;
}
v
}
pub fn clamp_top_k(&self, requested: usize) -> usize {
requested.min(self.retrieval_top_k)
}
}
impl Default for MemoryValidation {
fn default() -> Self {
Self {
max_size_bytes: 10_000,
max_name_length: 100,
required_fields: vec!["name".into(), "description".into()],
forbidden_patterns: vec![
("代码模式:".into(), "应从代码推,不应存储"),
("文件路径:".into(), "应从git推,不应存储"),
("架构:".into(), "应从实际代码推"),
("git历史:".into(), "git log是权威"),
("CLAUDE.md:".into(), "已在文档中"),
("TODO:".into(), "临时任务不应进记忆"),
],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a write request with the given name, description and body.
    fn request(name: &str, description: &str, content: &str) -> MemoryWriteRequest {
        MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: name.into(),
                description: description.into(),
                ..Default::default()
            },
            content: content.to_string(),
        }
    }

    #[test]
    fn memory_kind_labels_correct() {
        assert_eq!(MemoryKind::User.label(), "user");
        assert_eq!(MemoryKind::BehaviorPreference.label(), "feedback");
        assert_eq!(MemoryKind::Project.label(), "project");
        assert_eq!(MemoryKind::Reference.label(), "reference");
    }

    #[test]
    fn infer_kind_from_user_profile_fields() {
        let metadata = MemoryMetadata {
            user_role: Some("Senior Engineer".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::User);
    }

    #[test]
    fn infer_kind_from_preference_fields() {
        let metadata = MemoryMetadata {
            preference_rule: Some("Always use TypeScript".into()),
            ..Default::default()
        };
        assert_eq!(
            MemoryKind::infer_from_metadata(&metadata),
            MemoryKind::BehaviorPreference
        );
    }

    #[test]
    fn infer_kind_from_project_fields() {
        let metadata = MemoryMetadata {
            project_phase: Some("MVP".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::Project);
    }

    #[test]
    fn infer_kind_from_reference_fields() {
        let metadata = MemoryMetadata {
            ticket_ref: Some("PROJ-123".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::Reference);
    }

    #[test]
    fn infer_kind_user_signals_take_precedence() {
        let metadata = MemoryMetadata {
            expertise_level: Some("expert".into()),
            approved_pattern: Some("small PRs".into()),
            relative_date: Some("next week".into()),
            external_url: Some("https://example.com".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::User);
    }

    #[test]
    fn infer_kind_defaults_to_behavior_preference() {
        let metadata = MemoryMetadata::default();
        assert_eq!(
            MemoryKind::infer_from_metadata(&metadata),
            MemoryKind::BehaviorPreference
        );
    }

    #[test]
    fn validation_passes_for_valid_request() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "test-memory".into(),
                description: "A valid memory".into(),
                ..Default::default()
            },
            content: "This is fine".to_string(),
        };
        assert!(validation.validate(&request).is_ok());
    }

    #[test]
    fn validation_rejects_missing_name() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "".into(),
                description: "Missing name".into(),
                ..Default::default()
            },
            content: "content".to_string(),
        };
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::MissingRequiredField { field }) if field == "name"
        ));
    }

    #[test]
    fn validation_rejects_missing_description() {
        let validation = MemoryValidation::default();
        assert!(matches!(
            validation.validate(&request("named", "", "content")),
            Err(MemoryValidationError::MissingRequiredField { field }) if field == "description"
        ));
    }

    #[test]
    fn validation_rejects_long_name() {
        let validation = MemoryValidation::default();
        assert!(validation.validate(&request(&"n".repeat(100), "desc", "body")).is_ok());
        assert!(matches!(
            validation.validate(&request(&"n".repeat(101), "desc", "body")),
            Err(MemoryValidationError::NameTooLong { length: 101, limit: 100 })
        ));
    }

    #[test]
    fn validation_rejects_forbidden_pattern() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "bad-memory".into(),
                description: "Contains forbidden pattern".into(),
                ..Default::default()
            },
            content: "代码模式: 应该从代码推".to_string(),
        };
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::ForbiddenPattern { .. })
        ));
    }

    #[test]
    fn validation_rejects_oversized_content() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "huge-memory".into(),
                description: "Too large".into(),
                ..Default::default()
            },
            content: "x".repeat(20_000),
        };
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::ContentTooLarge { .. })
        ));
    }

    #[test]
    fn validation_content_limit_is_inclusive() {
        let validation = MemoryValidation::default();
        assert!(validation
            .validate(&request("edge", "desc", &"x".repeat(10_000)))
            .is_ok());
        assert!(matches!(
            validation.validate(&request("edge", "desc", &"x".repeat(10_001))),
            Err(MemoryValidationError::ContentTooLarge { size: 10_001, limit: 10_000 })
        ));
    }

    #[test]
    fn validate_memory_write_uses_default_profile() {
        assert!(validate_memory_write(&request("ok", "desc", "plain body")).is_ok());
        assert!(matches!(
            validate_memory_write(&request("todo", "desc", "TODO: later")),
            Err(MemoryValidationError::ForbiddenPattern { pattern, .. }) if pattern == "TODO:"
        ));
    }

    #[test]
    fn policy_overrides_validation_limits() {
        let policy = MemoryPolicy {
            max_content_bytes: Some(10),
            max_name_length: Some(3),
            ..Default::default()
        };
        let validation = policy.validation();
        assert_eq!(validation.max_size_bytes, 10);
        assert_eq!(validation.max_name_length, 3);
        assert_eq!(
            validation.required_fields,
            MemoryValidation::default().required_fields
        );
    }

    #[test]
    fn policy_clamps_top_k() {
        let policy = MemoryPolicy::default();
        assert_eq!(policy.clamp_top_k(50), 5);
        assert_eq!(policy.clamp_top_k(3), 3);
    }

    #[test]
    fn memory_query_defaults_top_k_to_5() {
        let query = MemoryQuery {
            current_context: "test".into(),
            ..Default::default()
        };
        assert_eq!(query.top_k, 5);
    }
}