Skip to main content

deepstrike_core/mm/
memory.rs

//! Long-term memory management (Phase 7).
//!
//! The kernel defines memory types and validation rules; SDKs perform I/O and selection.
//! This module performs no I/O — it contains only classification and validation logic.
5
6use serde::{Deserialize, Serialize};
7
8/// Memory kind (4 types, mirroring Claude Code's taxonomy).
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum MemoryKind {
12    /// User profile: who they are, expertise level, role.
13    User,
14    /// Behavior preference: what they like/dislike, approved patterns.
15    #[serde(rename = "feedback")]
16    BehaviorPreference,
17    /// Project context: what's happening, milestones, phases.
18    Project,
19    /// External pointer: where to find things (tickets, docs).
20    Reference,
21}
22
23impl MemoryKind {
24    pub fn label(self) -> &'static str {
25        match self {
26            Self::User => "user",
27            Self::BehaviorPreference => "feedback",
28            Self::Project => "project",
29            Self::Reference => "reference",
30        }
31    }
32
33    /// Infer memory kind from metadata fields (heuristic classifier).
34    pub fn infer_from_metadata(metadata: &MemoryMetadata) -> Self {
35        if metadata.user_role.is_some() || metadata.expertise_level.is_some() {
36            return MemoryKind::User;
37        }
38        if metadata.preference_rule.is_some() || metadata.approved_pattern.is_some() {
39            return MemoryKind::BehaviorPreference;
40        }
41        if metadata.project_phase.is_some() || metadata.relative_date.is_some() {
42            return MemoryKind::Project;
43        }
44        if metadata.external_url.is_some() || metadata.ticket_ref.is_some() {
45            return MemoryKind::Reference;
46        }
47        // Default: behavior preference (most common)
48        MemoryKind::BehaviorPreference
49    }
50}
51
52/// Lightweight memory metadata (kernel stores, SDK provides full content).
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54pub struct MemoryMetadata {
55    /// Memory slug (unique identifier).
56    pub name: String,
57
58    /// One-line description (for index display).
59    pub description: String,
60
61    /// Memory kind (optional; kernel infers if omitted).
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub kind: Option<MemoryKind>,
64
65    /// Creation timestamp (for stale warnings).
66    #[serde(default)]
67    pub created_at: u64,
68
69    /// Last update timestamp.
70    #[serde(default)]
71    pub updated_at: u64,
72
73    /// Associated session ID (for provenance).
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub session_id: Option<String>,
76
77    // --- Heuristic inference fields ---
78
79    /// User profile: role/title.
80    #[serde(default, skip_serializing_if = "Option::is_none")]
81    pub user_role: Option<String>,
82
83    /// User profile: expertise level.
84    #[serde(default, skip_serializing_if = "Option::is_none")]
85    pub expertise_level: Option<String>,
86
87    /// Behavior preference: rule/pattern.
88    #[serde(default, skip_serializing_if = "Option::is_none")]
89    pub preference_rule: Option<String>,
90
91    /// Behavior preference: approved pattern.
92    #[serde(default, skip_serializing_if = "Option::is_none")]
93    pub approved_pattern: Option<String>,
94
95    /// Project context: phase/milestone.
96    #[serde(default, skip_serializing_if = "Option::is_none")]
97    pub project_phase: Option<String>,
98
99    /// Project context: relative date (SDK must convert to absolute).
100    #[serde(default, skip_serializing_if = "Option::is_none")]
101    pub relative_date: Option<String>,
102
103    /// External pointer: URL.
104    #[serde(default, skip_serializing_if = "Option::is_none")]
105    pub external_url: Option<String>,
106
107    /// External pointer: ticket reference.
108    #[serde(default, skip_serializing_if = "Option::is_none")]
109    pub ticket_ref: Option<String>,
110}
111
112/// Memory write request (SDK → kernel).
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct MemoryWriteRequest {
115    pub metadata: MemoryMetadata,
116    pub content: String,
117}
118
119/// Memory query request (kernel → SDK).
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct MemoryQuery {
122    /// Current context summary (for selection).
123    pub current_context: String,
124
125    /// Active tools (filter recentTools).
126    #[serde(default, skip_serializing_if = "Vec::is_empty")]
127    pub active_tools: Vec<String>,
128
129    /// Recently surfaced memory IDs (filter alreadySurfaced).
130    #[serde(default, skip_serializing_if = "Vec::is_empty")]
131    pub already_surfaced: Vec<String>,
132
133    /// Return count limit (default: 5).
134    #[serde(default = "default_top_k")]
135    pub top_k: usize,
136}
137
/// Serde fallback for `MemoryQuery::top_k` when the field is absent from the input.
fn default_top_k() -> usize {
    5
}
139
140impl Default for MemoryQuery {
141    fn default() -> Self {
142        Self {
143            current_context: String::new(),
144            active_tools: Vec::new(),
145            already_surfaced: Vec::new(),
146            top_k: 5,
147        }
148    }
149}
150
151/// Memory retrieval response (SDK → kernel).
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct MemoryRetrieval {
154    /// Selected memory IDs.
155    pub selected_memory_ids: Vec<String>,
156
157    /// Selection rationale (for kernel logging).
158    pub selection_rationale: String,
159}
160
161/// Memory validation error.
162#[derive(Debug, Clone, Serialize, Deserialize)]
163#[serde(tag = "error_kind", rename_all = "snake_case")]
164pub enum MemoryValidationError {
165    MissingRequiredField { field: String },
166    ContentTooLarge { size: u32, limit: u32 },
167    ForbiddenPattern { pattern: String, reason: String },
168    InvalidKind { kind: String },
169    NameTooLong { length: usize, limit: usize },
170}
171
/// Rule set the kernel applies to memory write requests.
#[derive(Clone, Debug)]
pub struct MemoryValidation {
    /// Largest accepted content size, in bytes.
    pub max_size_bytes: u32,
    /// Largest accepted name length, in bytes.
    pub max_name_length: usize,
    /// Names of metadata fields that must be non-empty.
    pub required_fields: Vec<String>,
    /// Substrings that content must not contain, each paired with the reason it is rejected.
    pub forbidden_patterns: Vec<(String, &'static str)>,
}
180
181impl MemoryValidation {
182    /// Validate a memory write request.
183    pub fn validate(&self, request: &MemoryWriteRequest) -> Result<(), MemoryValidationError> {
184        // Check required fields
185        for field in &self.required_fields {
186            match field.as_str() {
187                "name" if request.metadata.name.is_empty() => {
188                    return Err(MemoryValidationError::MissingRequiredField { field: "name".into() });
189                }
190                "description" if request.metadata.description.is_empty() => {
191                    return Err(MemoryValidationError::MissingRequiredField { field: "description".into() });
192                }
193                _ => {}
194            }
195        }
196
197        // Check name length
198        if request.metadata.name.len() > self.max_name_length {
199            return Err(MemoryValidationError::NameTooLong {
200                length: request.metadata.name.len(),
201                limit: self.max_name_length,
202            });
203        }
204
205        // Check content size
206        if request.content.len() > self.max_size_bytes as usize {
207            return Err(MemoryValidationError::ContentTooLarge {
208                size: request.content.len() as u32,
209                limit: self.max_size_bytes,
210            });
211        }
212
213        // Check forbidden patterns
214        for (pattern, reason) in &self.forbidden_patterns {
215            if request.content.contains(pattern) {
216                return Err(MemoryValidationError::ForbiddenPattern {
217                    pattern: pattern.clone(),
218                    reason: reason.to_string(),
219                });
220            }
221        }
222
223        Ok(())
224    }
225}
226
227/// Validate a memory write request with default validation rules.
228pub fn validate_memory_write(request: &MemoryWriteRequest) -> Result<(), MemoryValidationError> {
229    MemoryValidation::default().validate(request)
230}
231
/// Declarative settings for the kernel's long-term memory subsystem.
///
/// A policy is opt-in and is installed through the `set_memory_policy` input event. Without one,
/// the kernel keeps its pre-policy behavior: each `write_memory` is validated against the default
/// rules and `query_memory` passes the requested `top_k` through unchanged. Once a policy is
/// installed, its settings take over:
/// - with `validation_enabled = false`, every write is admitted without validation;
/// - `retrieval_top_k` caps retrieval, so the emitted `requested_k` is
///   `min(query.top_k, retrieval_top_k)`;
/// - `max_content_bytes` and `max_name_length`, when set, replace the default validation limits.
///
/// `memory_path` and `stale_warning_days` are not enforced by the kernel, which does no recall
/// I/O. They are carried here so the SDK can read all of its configuration from one place.
#[derive(Clone, Debug)]
pub struct MemoryPolicy {
    /// Location of the memory store; passed through for the SDK.
    pub memory_path: String,
    /// Age in days after which a memory warrants a stale warning; passed through for the SDK.
    pub stale_warning_days: u32,
    /// Upper bound applied to a query's requested `top_k`.
    pub retrieval_top_k: usize,
    /// Whether writes are validated at all.
    pub validation_enabled: bool,
    /// Replacement for the default content size limit, in bytes, when `Some`.
    pub max_content_bytes: Option<u32>,
    /// Replacement for the default name length limit when `Some`.
    pub max_name_length: Option<usize>,
}
253
254impl Default for MemoryPolicy {
255    fn default() -> Self {
256        Self {
257            memory_path: String::new(),
258            stale_warning_days: 2,
259            retrieval_top_k: 5,
260            validation_enabled: true,
261            max_content_bytes: None,
262            max_name_length: None,
263        }
264    }
265}
266
267impl MemoryPolicy {
268    /// Build the validation rules this policy implies, starting from the kernel defaults and
269    /// applying any size / name-length overrides.
270    pub fn validation(&self) -> MemoryValidation {
271        let mut v = MemoryValidation::default();
272        if let Some(bytes) = self.max_content_bytes {
273            v.max_size_bytes = bytes;
274        }
275        if let Some(len) = self.max_name_length {
276            v.max_name_length = len;
277        }
278        v
279    }
280
281    /// Clamp a requested retrieval count to this policy's `retrieval_top_k` upper bound.
282    pub fn clamp_top_k(&self, requested: usize) -> usize {
283        requested.min(self.retrieval_top_k)
284    }
285}
286
287/// Default validation rules (aligned with Claude Code's "what NOT to store").
288impl Default for MemoryValidation {
289    fn default() -> Self {
290        Self {
291            max_size_bytes: 10_000,
292            max_name_length: 100,
293            required_fields: vec!["name".into(), "description".into()],
294            forbidden_patterns: vec![
295                ("代码模式:".into(), "应从代码推,不应存储"),
296                ("文件路径:".into(), "应从git推,不应存储"),
297                ("架构:".into(), "应从实际代码推"),
298                ("git历史:".into(), "git log是权威"),
299                ("CLAUDE.md:".into(), "已在文档中"),
300                ("TODO:".into(), "临时任务不应进记忆"),
301            ],
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    //! Unit tests for kind labelling and inference, default validation rules, and query defaults.

    use super::*;

    /// Builds a write request with the given name, description and content; every other
    /// metadata field is left at its default.
    fn request(name: &str, description: &str, content: &str) -> MemoryWriteRequest {
        MemoryWriteRequest {
            metadata: MemoryMetadata {
                name: name.into(),
                description: description.into(),
                ..Default::default()
            },
            content: content.into(),
        }
    }

    #[test]
    fn memory_kind_labels_correct() {
        assert_eq!(MemoryKind::User.label(), "user");
        assert_eq!(MemoryKind::BehaviorPreference.label(), "feedback");
        assert_eq!(MemoryKind::Project.label(), "project");
        assert_eq!(MemoryKind::Reference.label(), "reference");
    }

    #[test]
    fn infer_kind_from_user_profile_fields() {
        let metadata = MemoryMetadata {
            user_role: Some("Senior Engineer".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::User);
    }

    #[test]
    fn infer_kind_from_preference_fields() {
        let metadata = MemoryMetadata {
            preference_rule: Some("Always use TypeScript".into()),
            ..Default::default()
        };
        assert_eq!(
            MemoryKind::infer_from_metadata(&metadata),
            MemoryKind::BehaviorPreference
        );
    }

    #[test]
    fn infer_kind_from_project_fields() {
        let metadata = MemoryMetadata {
            project_phase: Some("MVP".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::Project);
    }

    #[test]
    fn infer_kind_from_reference_fields() {
        let by_url = MemoryMetadata {
            external_url: Some("https://example.com/doc".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&by_url), MemoryKind::Reference);

        let by_ticket = MemoryMetadata {
            ticket_ref: Some("PROJ-123".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&by_ticket), MemoryKind::Reference);
    }

    #[test]
    fn infer_kind_prefers_earlier_field_groups() {
        // User-profile fields are consulted before project and reference fields.
        let metadata = MemoryMetadata {
            user_role: Some("Senior Engineer".into()),
            project_phase: Some("MVP".into()),
            ticket_ref: Some("PROJ-123".into()),
            ..Default::default()
        };
        assert_eq!(MemoryKind::infer_from_metadata(&metadata), MemoryKind::User);
    }

    #[test]
    fn infer_kind_defaults_to_behavior_preference() {
        let metadata = MemoryMetadata::default();
        assert_eq!(
            MemoryKind::infer_from_metadata(&metadata),
            MemoryKind::BehaviorPreference
        );
    }

    #[test]
    fn validation_passes_for_valid_request() {
        let validation = MemoryValidation::default();
        let request = request("test-memory", "A valid memory", "This is fine");
        assert!(validation.validate(&request).is_ok());
    }

    #[test]
    fn validation_rejects_missing_name() {
        let validation = MemoryValidation::default();
        let request = request("", "Missing name", "content");
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::MissingRequiredField { field }) if field == "name"
        ));
    }

    #[test]
    fn validation_rejects_missing_description() {
        let validation = MemoryValidation::default();
        let request = request("no-description", "", "content");
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::MissingRequiredField { field }) if field == "description"
        ));
    }

    #[test]
    fn validation_rejects_overlong_name() {
        let validation = MemoryValidation::default();
        let request = request(&"n".repeat(101), "Name over the limit", "content");
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::NameTooLong { length: 101, limit: 100 })
        ));
    }

    #[test]
    fn validation_rejects_forbidden_pattern() {
        let validation = MemoryValidation::default();
        let request = request(
            "bad-memory",
            "Contains forbidden pattern",
            "代码模式: 应该从代码推",
        );
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::ForbiddenPattern { pattern, .. }) if pattern == "代码模式:"
        ));
    }

    #[test]
    fn validation_rejects_oversized_content() {
        let validation = MemoryValidation::default();
        let request = request("huge-memory", "Too large", &"x".repeat(20_000));
        assert!(matches!(
            validation.validate(&request),
            Err(MemoryValidationError::ContentTooLarge { size: 20_000, limit: 10_000 })
        ));
    }

    #[test]
    fn validation_accepts_content_at_size_limit() {
        // The limit is inclusive: exactly `max_size_bytes` bytes of content is accepted.
        let validation = MemoryValidation::default();
        let request = request("edge-memory", "Exactly at the limit", &"x".repeat(10_000));
        assert!(validation.validate(&request).is_ok());
    }

    #[test]
    fn memory_query_defaults_top_k_to_5() {
        let query = MemoryQuery {
            current_context: "test".into(),
            ..Default::default()
        };
        assert_eq!(query.top_k, 5);
    }
}