Skip to main content

oxios_memory/memory/
auto_classify.rs

//! Automatic memory type classification from content.
//!
//! Infers a `MemoryType` from content text using substring pattern matching.
//! Used when memories are stored without an explicit type, and by
//! Dream Phase 2 for re-classification.
6
7use crate::memory::types::MemoryType;
8
9// ---------------------------------------------------------------------------
10// Pattern constants
11// ---------------------------------------------------------------------------
12
/// Phrases that signal the user is correcting or contradicting something
/// said earlier.
///
/// Entries are lowercase and are searched for as plain substrings of the
/// lowercased content, so an entry also matches inside a longer word.
const CORRECTION_PATTERNS: &[&str] = &[
    // --- English ---
    "actually",
    "no, it's",
    "that's wrong",
    "correction",
    "i meant",
    "not that",
    "i was wrong",
    // --- Korean (corrections / rebuttals) ---
    "아니라", // "not ..., but rather"
    "실제로는", // "actually / in reality"
    "정정", // "correction"
    "틀렸", // stem of "was wrong"
    "잘못 알고", // "mistakenly thought"
    // NOTE(review): already covered by "틀렸" above, since matching is by substring.
    "내가 틀렸", // "I was wrong"
    "고쳐", // "fix (it)"
];
31
/// Phrases that signal a personal preference or taste.
///
/// Entries are lowercase and are searched for as plain substrings of the
/// lowercased content.
const PREFERENCE_PATTERNS: &[&str] = &[
    // --- English ---
    "i prefer",
    "always use",
    "i like",
    "i don't",
    "never use",
    "i'd rather",
    "my preference",
    "please use",
    "make sure to use",
    // --- Korean (preference / taste) ---
    "선호", // "prefer / preference"
    "좋아", // "like"
    // NOTE(review): already covered by "좋아" above, since matching is by substring.
    "내가 좋아하", // stem of "I like"
    "항상 써", // "always use"
    "절대 쓰지", // stem of "never use"
    "차라리", // "rather"
    "내 취향", // "my taste"
];
52
/// Phrases that signal a decision has been made.
///
/// Entries are lowercase and are searched for as plain substrings of the
/// lowercased content.
const DECISION_PATTERNS: &[&str] = &[
    // --- English ---
    "decided to",
    "we chose",
    "let's go with",
    "we'll use",
    "i decided",
    "the decision is",
    "going with",
    // --- Korean (decisions) ---
    "결정했", // stem of "decided"
    // NOTE(review): already covered by "결정했" above, since matching is by substring.
    "결정했어", // "decided" (informal)
    "하기로 했", // stem of "decided to do"
    "선택했", // stem of "chose"
    // NOTE(review): this is just "we" + topic marker, so it matches very broadly.
    "우리는",
    "사용하기로", // "(decided) to use"
];
70
/// Phrases that signal a skill, routine or procedure.
///
/// Entries are lowercase and are searched for as plain substrings of the
/// lowercased content.
const SKILL_PATTERNS: &[&str] = &[
    // --- English ---
    "always run",
    "before commit",
    "every time",
    "make sure to",
    "workflow is",
    "standard procedure",
    "first, then",
    "step by step",
    // --- Korean (procedures / how-to) ---
    "하는 방법", // "how to do"
    "이렇게 하는", // "doing it this way"
    "항상 실행", // "always run"
    "커밋하기 전", // "before committing"
    "표준 절차", // "standard procedure"
    "순서대로", // "in order"
    "먼저 그리고", // "first, and (then)"
];
90
/// Phrases that signal the user is describing themselves (name, job, role,
/// background).
///
/// Entries are lowercase and are searched for as plain substrings of the
/// lowercased content. Trailing spaces inside an entry are significant.
const PROFILE_PATTERNS: &[&str] = &[
    // --- English ---
    "my name is",
    "i work at",
    // The trailing space presumably keeps these two from matching words such
    // as "i'm available" -- confirm before trimming.
    "i'm a ",
    "i am a ",
    "my role is",
    "my job is",
    "i specialize",
    "my background",
    // --- Korean (profile / personal details) ---
    "내 이름은", // "my name is"
    "제 이름은", // "my name is" (polite)
    // NOTE(review): the next two are just "I" + topic marker, so they match
    // very broadly.
    "나는 ",
    "저는 ",
    "직업은", // "(my) job is"
    "일하고 있", // stem of "am working"
    "전문", // "specialty / expert"
    "내 배경", // "my background"
];
111
/// Phrases that signal an episode, i.e. something that happened at a point
/// in time (a deploy, a release, a task finishing).
///
/// Entries are lowercase and are searched for as plain substrings of the
/// lowercased content. Korean past-tense stems are included so that Korean
/// event reports are recognised the same way the other pattern tables
/// recognise Korean input.
const EPISODE_PATTERNS: &[&str] = &[
    // --- English ---
    "deployed",
    "released",
    "launched",
    "completed",
    "finished",
    "started",
    // --- Korean (events) ---
    "배포했", // stem of "deployed"
    "출시했", // stem of "released / launched"
    "릴리스했", // stem of "released"
    "완료했", // stem of "completed"
    "끝냈", // stem of "finished"
    "시작했", // stem of "started"
];
121
122// ---------------------------------------------------------------------------
123// AutoClassifier
124// ---------------------------------------------------------------------------
125
126/// Automatic memory type classifier.
127///
128/// Uses pattern matching to infer memory types from content text.
129/// Falls back to `Fact` when no specific type is detected.
130pub struct AutoClassifier;
131
132impl AutoClassifier {
133    /// Classify a new memory entry from its content and optional context.
134    ///
135    /// Returns the inferred `MemoryType`. Falls back to `Fact` if no
136    /// specific type can be determined.
137    pub fn infer_memory_type(content: &str, _context: &str) -> MemoryType {
138        let content_lower = content.to_lowercase();
139
140        // Priority order:
141        // 1. Correction → Fact (overrides everything)
142        // 2. Preference
143        // 3. Decision
144        // 4. Skill/Procedure
145        // 5. Profile
146        // 6. Episode
147        // 7. Default → Fact
148
149        if Self::is_correction(&content_lower) {
150            return MemoryType::Fact;
151        }
152
153        if Self::is_preference(&content_lower) {
154            return MemoryType::Preference;
155        }
156
157        if Self::is_decision(&content_lower) {
158            return MemoryType::Decision;
159        }
160
161        if Self::is_skill(&content_lower) {
162            return MemoryType::Skill;
163        }
164
165        if Self::is_profile(&content_lower) {
166            return MemoryType::UserProfile;
167        }
168
169        if Self::is_episode(&content_lower) {
170            return MemoryType::Episode;
171        }
172
173        MemoryType::Fact
174    }
175
176    fn is_correction(content_lower: &str) -> bool {
177        CORRECTION_PATTERNS
178            .iter()
179            .any(|p| content_lower.contains(p))
180    }
181
182    fn is_preference(content_lower: &str) -> bool {
183        PREFERENCE_PATTERNS
184            .iter()
185            .any(|p| content_lower.contains(p))
186    }
187
188    fn is_decision(content_lower: &str) -> bool {
189        DECISION_PATTERNS.iter().any(|p| content_lower.contains(p))
190    }
191
192    fn is_skill(content_lower: &str) -> bool {
193        SKILL_PATTERNS.iter().any(|p| content_lower.contains(p))
194    }
195
196    fn is_profile(content_lower: &str) -> bool {
197        PROFILE_PATTERNS.iter().any(|p| content_lower.contains(p))
198    }
199
200    fn is_episode(content_lower: &str) -> bool {
201        EPISODE_PATTERNS.iter().any(|p| content_lower.contains(p))
202    }
203
204    /// Extract tags from content for search indexing.
205    ///
206    /// Keyword extraction: split on whitespace, filter short/stop words, then
207    /// pick the top-N by frequency (ties broken alphabetically) so that the
208    /// most salient terms are retained instead of an arbitrary alphabetical
209    /// prefix.
210    pub fn extract_tags(content: &str, max_tags: usize) -> Vec<String> {
211        use std::collections::HashMap;
212
213        let mut counts: HashMap<String, u32> = HashMap::new();
214        for word in content.split_whitespace() {
215            let w = word
216                .trim_matches(|c: char| c.is_ascii_punctuation())
217                .to_lowercase();
218            if w.len() > 3 && !Self::is_stop_word(&w) {
219                *counts.entry(w).or_default() += 1;
220            }
221        }
222
223        let mut tags: Vec<(String, u32)> = counts.into_iter().collect();
224        // Sort by frequency descending, then alphabetically for determinism.
225        tags.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
226        tags.into_iter().take(max_tags).map(|(w, _)| w).collect()
227    }
228
229    fn is_stop_word(word: &str) -> bool {
230        const STOP: &[&str] = &[
231            "that", "this", "with", "from", "have", "been", "were", "will", "would", "could",
232            "should", "about", "which", "their", "there", "these", "those", "other", "than",
233            "then", "also", "some",
234        ];
235        STOP.contains(&word)
236    }
237}
238
239// ---------------------------------------------------------------------------
240// Tests
241// ---------------------------------------------------------------------------
242
#[cfg(test)]
mod tests {
    //! Unit tests for `AutoClassifier`.
    //!
    //! Covers one or more English phrases per memory type, a Korean phrase
    //! for each table that has Korean entries, the priority order between
    //! tables, and the exact ranking produced by `extract_tags`.

    use super::*;

    #[test]
    fn test_classify_correction() {
        assert_eq!(
            AutoClassifier::infer_memory_type("Actually, the port is 8080 not 3000", ""),
            MemoryType::Fact
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("Correction: the API key expired", ""),
            MemoryType::Fact
        );
    }

    #[test]
    fn test_classify_preference() {
        assert_eq!(
            AutoClassifier::infer_memory_type("I prefer dark mode for the editor", ""),
            MemoryType::Preference
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("Never use tabs, always use spaces", ""),
            MemoryType::Preference
        );
    }

    #[test]
    fn test_classify_decision() {
        assert_eq!(
            AutoClassifier::infer_memory_type("We decided to use Tokio for async runtime", ""),
            MemoryType::Decision
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("Let's go with the microservice approach", ""),
            MemoryType::Decision
        );
    }

    #[test]
    fn test_classify_skill() {
        assert_eq!(
            AutoClassifier::infer_memory_type("Always run tests before commit", ""),
            MemoryType::Skill
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("Standard procedure: lint, test, then deploy", ""),
            MemoryType::Skill
        );
    }

    #[test]
    fn test_classify_profile() {
        assert_eq!(
            AutoClassifier::infer_memory_type("My name is Won and I work at Oxios", ""),
            MemoryType::UserProfile
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("I'm a backend engineer", ""),
            MemoryType::UserProfile
        );
    }

    #[test]
    fn test_classify_episode() {
        assert_eq!(
            AutoClassifier::infer_memory_type("Released v0.2.0 with memory consolidation", ""),
            MemoryType::Episode
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("Deployed the new API gateway yesterday", ""),
            MemoryType::Episode
        );
    }

    #[test]
    fn test_classify_default_fact() {
        assert_eq!(
            AutoClassifier::infer_memory_type("API uses port 3000", ""),
            MemoryType::Fact
        );
        assert_eq!(
            AutoClassifier::infer_memory_type("The database has 42 tables", ""),
            MemoryType::Fact
        );
        // Empty content matches nothing and falls back to the default.
        assert_eq!(AutoClassifier::infer_memory_type("", ""), MemoryType::Fact);
    }

    #[test]
    fn test_classify_is_case_insensitive() {
        assert_eq!(
            AutoClassifier::infer_memory_type("I PREFER DARK MODE", ""),
            MemoryType::Preference
        );
    }

    #[test]
    fn test_classify_korean_patterns() {
        // "I prefer dark mode."
        assert_eq!(
            AutoClassifier::infer_memory_type("다크 모드를 선호합니다", ""),
            MemoryType::Preference
        );
        // "We decided to use Redis."
        assert_eq!(
            AutoClassifier::infer_memory_type("Redis를 사용하기로 했습니다", ""),
            MemoryType::Decision
        );
        // "Run the tests before committing."
        assert_eq!(
            AutoClassifier::infer_memory_type("커밋하기 전에 테스트를 돌린다", ""),
            MemoryType::Skill
        );
        // "My name is Won."
        assert_eq!(
            AutoClassifier::infer_memory_type("제 이름은 원입니다", ""),
            MemoryType::UserProfile
        );
    }

    #[test]
    fn test_classify_priority_order() {
        // A correction wins over preference wording in the same sentence.
        assert_eq!(
            AutoClassifier::infer_memory_type("Actually, I prefer tabs over spaces", ""),
            MemoryType::Fact
        );
        // Same rule for Korean: "Actually, I prefer tabs."
        assert_eq!(
            AutoClassifier::infer_memory_type("실제로는 탭을 선호합니다", ""),
            MemoryType::Fact
        );
        // Preference is checked before decision.
        assert_eq!(
            AutoClassifier::infer_memory_type("I prefer the stack we decided to use", ""),
            MemoryType::Preference
        );
    }

    #[test]
    fn test_extract_tags() {
        let tags =
            AutoClassifier::extract_tags("Rust tokio async runtime memory consolidation system", 5);
        assert!(!tags.is_empty());
        assert!(
            tags.iter()
                .any(|t| t.contains("rust") || t.contains("memory"))
        );
        // All words occur once, so the alphabetical tie-break decides the top five.
        assert_eq!(
            tags,
            vec!["async", "consolidation", "memory", "runtime", "rust"]
        );
    }

    #[test]
    fn test_extract_tags_ranks_by_frequency() {
        // "zebra" occurs twice and therefore outranks the alphabetically earlier "apple".
        assert_eq!(
            AutoClassifier::extract_tags("zebra apple zebra", 5),
            vec!["zebra", "apple"]
        );
        // Punctuation is trimmed, case is folded, stop words and short words are dropped.
        assert_eq!(
            AutoClassifier::extract_tags("cache index cache, Cache! that api", 5),
            vec!["cache", "index"]
        );
        assert_eq!(
            AutoClassifier::extract_tags("cache index cache, Cache! that api", 1),
            vec!["cache"]
        );
    }

    #[test]
    fn test_extract_tags_empty_results() {
        assert!(AutoClassifier::extract_tags("", 5).is_empty());
        assert!(AutoClassifier::extract_tags("memory consolidation", 0).is_empty());
        assert!(AutoClassifier::extract_tags("the and api", 5).is_empty());
    }
}
337            tags.iter()
338                .any(|t| t.contains("rust") || t.contains("memory"))
339        );
340    }
341}