1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum MemoryKind {
12 User,
14 #[serde(rename = "feedback")]
16 BehaviorPreference,
17 Project,
19 Reference,
21}
22
23impl MemoryKind {
24 pub fn label(self) -> &'static str {
25 match self {
26 Self::User => "user",
27 Self::BehaviorPreference => "feedback",
28 Self::Project => "project",
29 Self::Reference => "reference",
30 }
31 }
32
33 pub fn infer_from_metadata(metadata: &MemoryMetadata) -> Self {
35 if metadata.user_role.is_some() || metadata.expertise_level.is_some() {
36 return MemoryKind::User;
37 }
38 if metadata.preference_rule.is_some() || metadata.approved_pattern.is_some() {
39 return MemoryKind::BehaviorPreference;
40 }
41 if metadata.project_phase.is_some() || metadata.relative_date.is_some() {
42 return MemoryKind::Project;
43 }
44 if metadata.external_url.is_some() || metadata.ticket_ref.is_some() {
45 return MemoryKind::Reference;
46 }
47 MemoryKind::BehaviorPreference
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54pub struct MemoryMetadata {
55 pub name: String,
57
58 pub description: String,
60
61 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub kind: Option<MemoryKind>,
64
65 #[serde(default)]
67 pub created_at: u64,
68
69 #[serde(default)]
71 pub updated_at: u64,
72
73 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub session_id: Option<String>,
76
77 #[serde(default, skip_serializing_if = "Option::is_none")]
81 pub user_role: Option<String>,
82
83 #[serde(default, skip_serializing_if = "Option::is_none")]
85 pub expertise_level: Option<String>,
86
87 #[serde(default, skip_serializing_if = "Option::is_none")]
89 pub preference_rule: Option<String>,
90
91 #[serde(default, skip_serializing_if = "Option::is_none")]
93 pub approved_pattern: Option<String>,
94
95 #[serde(default, skip_serializing_if = "Option::is_none")]
97 pub project_phase: Option<String>,
98
99 #[serde(default, skip_serializing_if = "Option::is_none")]
101 pub relative_date: Option<String>,
102
103 #[serde(default, skip_serializing_if = "Option::is_none")]
105 pub external_url: Option<String>,
106
107 #[serde(default, skip_serializing_if = "Option::is_none")]
109 pub ticket_ref: Option<String>,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct MemoryWriteRequest {
115 pub metadata: MemoryMetadata,
116 pub content: String,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct MemoryQuery {
122 pub current_context: String,
124
125 #[serde(default, skip_serializing_if = "Vec::is_empty")]
127 pub active_tools: Vec<String>,
128
129 #[serde(default, skip_serializing_if = "Vec::is_empty")]
131 pub already_surfaced: Vec<String>,
132
133 #[serde(default = "default_top_k")]
135 pub top_k: usize,
136}
137
/// Serde default for `MemoryQuery::top_k`: five results.
fn default_top_k() -> usize {
    5
}
139
140impl Default for MemoryQuery {
141 fn default() -> Self {
142 Self {
143 current_context: String::new(),
144 active_tools: Vec::new(),
145 already_surfaced: Vec::new(),
146 top_k: 5,
147 }
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct MemoryRetrieval {
154 pub selected_memory_ids: Vec<String>,
156
157 pub selection_rationale: String,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
163#[serde(tag = "error_kind", rename_all = "snake_case")]
164pub enum MemoryValidationError {
165 MissingRequiredField { field: String },
166 ContentTooLarge { size: u32, limit: u32 },
167 ForbiddenPattern { pattern: String, reason: String },
168 InvalidKind { kind: String },
169 NameTooLong { length: usize, limit: usize },
170}
171
/// Configurable rule set applied to a [`MemoryWriteRequest`] before it is stored.
#[derive(Clone, Debug)]
pub struct MemoryValidation {
    /// Maximum accepted `content` length, in bytes.
    pub max_size_bytes: u32,
    /// Maximum accepted `metadata.name` length, in bytes (`str::len`).
    pub max_name_length: usize,
    /// Metadata fields that must be non-empty. Only `"name"` and `"description"`
    /// are recognized; other entries have no effect.
    pub required_fields: Vec<String>,
    /// `(substring, reason)` pairs: content containing `substring` is rejected and
    /// `reason` is reported.
    pub forbidden_patterns: Vec<(String, &'static str)>,
}
180
181impl MemoryValidation {
182 pub fn validate(&self, request: &MemoryWriteRequest) -> Result<(), MemoryValidationError> {
184 for field in &self.required_fields {
186 match field.as_str() {
187 "name" if request.metadata.name.is_empty() => {
188 return Err(MemoryValidationError::MissingRequiredField { field: "name".into() });
189 }
190 "description" if request.metadata.description.is_empty() => {
191 return Err(MemoryValidationError::MissingRequiredField { field: "description".into() });
192 }
193 _ => {}
194 }
195 }
196
197 if request.metadata.name.len() > self.max_name_length {
199 return Err(MemoryValidationError::NameTooLong {
200 length: request.metadata.name.len(),
201 limit: self.max_name_length,
202 });
203 }
204
205 if request.content.len() > self.max_size_bytes as usize {
207 return Err(MemoryValidationError::ContentTooLarge {
208 size: request.content.len() as u32,
209 limit: self.max_size_bytes,
210 });
211 }
212
213 for (pattern, reason) in &self.forbidden_patterns {
215 if request.content.contains(pattern) {
216 return Err(MemoryValidationError::ForbiddenPattern {
217 pattern: pattern.clone(),
218 reason: reason.to_string(),
219 });
220 }
221 }
222
223 Ok(())
224 }
225}
226
227pub fn validate_memory_write(request: &MemoryWriteRequest) -> Result<(), MemoryValidationError> {
229 MemoryValidation::default().validate(request)
230}
231
/// Tunable settings for storing and retrieving memories.
#[derive(Clone, Debug)]
pub struct MemoryPolicy {
    /// Location of the memory store.
    /// NOTE(review): format (directory vs file) not visible here — confirm.
    pub memory_path: String,
    /// Age threshold, in days, for stale warnings.
    /// NOTE(review): not read in this file — confirm semantics with its consumer.
    pub stale_warning_days: u32,
    /// Upper bound applied to requested result counts by `clamp_top_k`.
    pub retrieval_top_k: usize,
    /// Whether validation should run.
    /// NOTE(review): not consulted by `validation()` — callers must check it.
    pub validation_enabled: bool,
    /// Overrides `MemoryValidation::max_size_bytes` when `Some`.
    pub max_content_bytes: Option<u32>,
    /// Overrides `MemoryValidation::max_name_length` when `Some`.
    pub max_name_length: Option<usize>,
}
253
254impl Default for MemoryPolicy {
255 fn default() -> Self {
256 Self {
257 memory_path: String::new(),
258 stale_warning_days: 2,
259 retrieval_top_k: 5,
260 validation_enabled: true,
261 max_content_bytes: None,
262 max_name_length: None,
263 }
264 }
265}
266
267impl MemoryPolicy {
268 pub fn validation(&self) -> MemoryValidation {
271 let mut v = MemoryValidation::default();
272 if let Some(bytes) = self.max_content_bytes {
273 v.max_size_bytes = bytes;
274 }
275 if let Some(len) = self.max_name_length {
276 v.max_name_length = len;
277 }
278 v
279 }
280
281 pub fn clamp_top_k(&self, requested: usize) -> usize {
283 requested.min(self.retrieval_top_k)
284 }
285}
286
287impl Default for MemoryValidation {
289 fn default() -> Self {
290 Self {
291 max_size_bytes: 10_000,
292 max_name_length: 100,
293 required_fields: vec!["name".into(), "description".into()],
294 forbidden_patterns: vec![
295 ("代码模式:".into(), "应从代码推,不应存储"),
296 ("文件路径:".into(), "应从git推,不应存储"),
297 ("架构:".into(), "应从实际代码推"),
298 ("git历史:".into(), "git log是权威"),
299 ("CLAUDE.md:".into(), "已在文档中"),
300 ("TODO:".into(), "临时任务不应进记忆"),
301 ],
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a write request with the given name, description and content.
    fn write_request(name: &str, description: &str, content: &str) -> MemoryWriteRequest {
        MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: name.into(),
                description: description.into(),
                ..Default::default()
            },
            content: content.to_string(),
        }
    }

    #[test]
    fn memory_kind_labels_correct() {
        assert_eq!(MemoryKind::User.label(), "user");
        assert_eq!(MemoryKind::BehaviorPreference.label(), "feedback");
        assert_eq!(MemoryKind::Project.label(), "project");
        assert_eq!(MemoryKind::Reference.label(), "reference");
    }

    #[test]
    fn infer_kind_from_user_profile_fields() {
        let metadata = MemoryMetadata {
            user_role: Some("Senior Engineer".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::User);
    }

    #[test]
    fn infer_kind_from_preference_fields() {
        let metadata = MemoryMetadata {
            preference_rule: Some("Always use TypeScript".into()),
            ..Default::default()
        };
        assert_eq!(
            MemoryKind::infer_from_metadata(&metadata),
            MemoryKind::BehaviorPreference
        );
    }

    #[test]
    fn infer_kind_from_project_fields() {
        let metadata = MemoryMetadata {
            project_phase: Some("MVP".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::Project);
    }

    #[test]
    fn infer_kind_from_reference_fields() {
        let metadata = MemoryMetadata {
            ticket_ref: Some("PROJ-123".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::Reference);
    }

    #[test]
    fn infer_kind_prefers_user_signals_over_others() {
        let metadata = MemoryMetadata {
            user_role: Some("Designer".into()),
            preference_rule: Some("Prefer tabs".into()),
            external_url: Some("https://example.com".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::User);
    }

    #[test]
    fn infer_kind_defaults_to_behavior_preference() {
        let metadata = MemoryMetadata::default();
        assert_eq!(
            MemoryKind::infer_from_metadata(&metadata),
            MemoryKind::BehaviorPreference
        );
    }

    #[test]
    fn validation_passes_for_valid_request() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "test-memory".into(),
                description: "A valid memory".into(),
                ..Default::default()
            },
            content: "This is fine".to_string(),
        };
        assert!(validation.validate(&request).is_ok());
    }

    #[test]
    fn validation_rejects_missing_name() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "".into(),
                description: "Missing name".into(),
                ..Default::default()
            },
            content: "content".to_string(),
        };
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::MissingRequiredField { field }) if field == "name"
        ));
    }

    #[test]
    fn validation_rejects_missing_description() {
        let request = write_request("named", "", "content");
        assert!(matches!(
            MemoryValidation::default().validate(&request),
            Err(MemoryValidationError::MissingRequiredField { field }) if field == "description"
        ));
    }

    #[test]
    fn validation_rejects_long_name() {
        let request = write_request(&"a".repeat(101), "Too long a name", "content");
        assert!(matches!(
            MemoryValidation::default().validate(&request),
            Err(MemoryValidationError::NameTooLong { length: 101, limit: 100 })
        ));
    }

    #[test]
    fn validation_rejects_forbidden_pattern() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "bad-memory".into(),
                description: "Contains forbidden pattern".into(),
                ..Default::default()
            },
            content: "代码模式: 应该从代码推".to_string(),
        };
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::ForbiddenPattern { .. })
        ));
    }

    #[test]
    fn validation_rejects_oversized_content() {
        let validation = MemoryValidation::default();
        let request = MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: "huge-memory".into(),
                description: "Too large".into(),
                ..Default::default()
            },
            content: "x".repeat(20_000),
        };
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::ContentTooLarge { .. })
        ));
    }

    #[test]
    fn validation_reports_exact_oversize_figures() {
        let request = write_request("huge", "Too large", &"x".repeat(20_000));
        assert!(matches!(
            MemoryValidation::default().validate(&request),
            Err(MemoryValidationError::ContentTooLarge { size: 20_000, limit: 10_000 })
        ));
    }

    #[test]
    fn validate_memory_write_applies_default_rules() {
        assert!(validate_memory_write(&write_request("ok", "fine", "plain text")).is_ok());
        assert!(validate_memory_write(&write_request("bad", "todo", "TODO: later")).is_err());
    }

    #[test]
    fn policy_overrides_validation_limits() {
        let policy = MemoryPolicy {
            max_content_bytes: Some(16),
            max_name_length: Some(4),
            ..Default::default()
        };
        let validation = policy.validation();
        let defaults = MemoryValidation::default();
        assert_eq!(validation.max_size_bytes, 16);
        assert_eq!(validation.max_name_length, 4);
        assert_eq!(validation.required_fields, defaults.required_fields);
        assert_eq!(validation.forbidden_patterns.len(), defaults.forbidden_patterns.len());
        assert!(matches!(
            validation.validate(&write_request("abcde", "desc", "c")),
            Err(MemoryValidationError::NameTooLong { length: 5, limit: 4 })
        ));
    }

    #[test]
    fn policy_without_overrides_keeps_validation_defaults() {
        let validation = MemoryPolicy::default().validation();
        let defaults = MemoryValidation::default();
        assert_eq!(validation.max_size_bytes, defaults.max_size_bytes);
        assert_eq!(validation.max_name_length, defaults.max_name_length);
    }

    #[test]
    fn clamp_top_k_caps_at_policy_limit() {
        let policy = MemoryPolicy::default();
        assert_eq!(policy.clamp_top_k(0), 0);
        assert_eq!(policy.clamp_top_k(3), 3);
        assert_eq!(policy.clamp_top_k(10), policy.retrieval_top_k);
    }

    #[test]
    fn memory_query_defaults_top_k_to_5() {
        let query = MemoryQuery {
            current_context: "test".into(),
            ..Default::default()
        };
        assert_eq!(query.top_k, 5);
    }
}