Skip to main content

oxios_memory/memory/
auto_classify.rs

//! Automatic memory type classification from content.
//!
//! Infers `MemoryType` from content text using pattern matching.
//! Used when memories are stored without an explicit type, and by
//! Dream Phase 2 for re-classification.
6
7use crate::memory::types::MemoryType;
8
// ---------------------------------------------------------------------------
// Pattern constants
// ---------------------------------------------------------------------------
12
/// Phrases signalling that the user is correcting or retracting earlier information.
///
/// Entries are lower-case because they are compared against lower-cased
/// content. A hit here outranks every other category during classification.
const CORRECTION_PATTERNS: &[&str] = &[
    "actually",
    "no, it's",
    "that's wrong",
    "correction",
    "i meant",
    "not that",
    "i was wrong",
];
23
/// Phrases expressing a personal preference, habit, or taste.
///
/// Entries are lower-case because they are compared against lower-cased content.
const PREFERENCE_PATTERNS: &[&str] = &[
    "i prefer",
    "always use",
    "i like",
    "i don't",
    "never use",
    "i'd rather",
    "my preference",
    "please use",
    "make sure to use",
];
36
/// Phrases announcing that a choice has been made.
///
/// Entries are lower-case because they are compared against lower-cased content.
const DECISION_PATTERNS: &[&str] = &[
    "decided to",
    "we chose",
    "let's go with",
    "we'll use",
    "i decided",
    "the decision is",
    "going with",
];
47
/// Phrases describing a repeatable procedure or workflow.
///
/// Entries are lower-case because they are compared against lower-cased
/// content. Note that `"make sure to"` overlaps the preference phrase
/// `"make sure to use"`; preferences are checked first, so that longer
/// phrase classifies as a preference.
const SKILL_PATTERNS: &[&str] = &[
    "always run",
    "before commit",
    "every time",
    "make sure to",
    "workflow is",
    "standard procedure",
    "first, then",
    "step by step",
];
59
/// Phrases revealing facts about the user's identity, role, or background.
///
/// Entries are lower-case because they are compared against lower-cased
/// content. The trailing space in `"i'm a "` and `"i am a "` keeps them from
/// matching phrases such as "i am also".
const PROFILE_PATTERNS: &[&str] = &[
    "my name is",
    "i work at",
    "i'm a ",
    "i am a ",
    "my role is",
    "my job is",
    "i specialize",
    "my background",
];
71
/// Past-tense verbs marking a concrete event or milestone.
///
/// Entries are lower-case because they are compared against lower-cased
/// content. This is the last category tried before falling back to `Fact`.
const EPISODE_PATTERNS: &[&str] = &[
    "deployed",
    "released",
    "launched",
    "completed",
    "finished",
    "started",
];
81
// ---------------------------------------------------------------------------
// AutoClassifier
// ---------------------------------------------------------------------------
85
/// Automatic memory type classifier.
///
/// A stateless unit type grouping pattern-matching helpers that infer a
/// `MemoryType` from content text and extract search tags. Classification
/// falls back to `Fact` when no specific type is detected.
#[derive(Debug, Clone, Copy, Default)]
pub struct AutoClassifier;
91
92impl AutoClassifier {
93    /// Classify a new memory entry from its content and optional context.
94    ///
95    /// Returns the inferred `MemoryType`. Falls back to `Fact` if no
96    /// specific type can be determined.
97    pub fn infer_memory_type(content: &str, _context: &str) -> MemoryType {
98        let content_lower = content.to_lowercase();
99
100        // Priority order:
101        // 1. Correction → Fact (overrides everything)
102        // 2. Preference
103        // 3. Decision
104        // 4. Skill/Procedure
105        // 5. Profile
106        // 6. Episode
107        // 7. Default → Fact
108
109        if Self::is_correction(&content_lower) {
110            return MemoryType::Fact;
111        }
112
113        if Self::is_preference(&content_lower) {
114            return MemoryType::Preference;
115        }
116
117        if Self::is_decision(&content_lower) {
118            return MemoryType::Decision;
119        }
120
121        if Self::is_skill(&content_lower) {
122            return MemoryType::Skill;
123        }
124
125        if Self::is_profile(&content_lower) {
126            return MemoryType::UserProfile;
127        }
128
129        if Self::is_episode(&content_lower) {
130            return MemoryType::Episode;
131        }
132
133        MemoryType::Fact
134    }
135
136    fn is_correction(content_lower: &str) -> bool {
137        CORRECTION_PATTERNS
138            .iter()
139            .any(|p| content_lower.contains(p))
140    }
141
142    fn is_preference(content_lower: &str) -> bool {
143        PREFERENCE_PATTERNS
144            .iter()
145            .any(|p| content_lower.contains(p))
146    }
147
148    fn is_decision(content_lower: &str) -> bool {
149        DECISION_PATTERNS.iter().any(|p| content_lower.contains(p))
150    }
151
152    fn is_skill(content_lower: &str) -> bool {
153        SKILL_PATTERNS.iter().any(|p| content_lower.contains(p))
154    }
155
156    fn is_profile(content_lower: &str) -> bool {
157        PROFILE_PATTERNS.iter().any(|p| content_lower.contains(p))
158    }
159
160    fn is_episode(content_lower: &str) -> bool {
161        EPISODE_PATTERNS.iter().any(|p| content_lower.contains(p))
162    }
163
164    /// Extract tags from content for search indexing.
165    ///
166    /// Simple keyword extraction: split on whitespace, filter short words,
167    /// take top N unique terms.
168    pub fn extract_tags(content: &str, max_tags: usize) -> Vec<String> {
169        let mut tags: Vec<String> = content
170            .split_whitespace()
171            .map(|w| {
172                w.trim_matches(|c: char| c.is_ascii_punctuation())
173                    .to_lowercase()
174            })
175            .filter(|w| w.len() > 3 && !Self::is_stop_word(w))
176            .collect();
177
178        tags.sort();
179        tags.dedup();
180        tags.truncate(max_tags);
181        tags
182    }
183
184    fn is_stop_word(word: &str) -> bool {
185        const STOP: &[&str] = &[
186            "that", "this", "with", "from", "have", "been", "were", "will", "would", "could",
187            "should", "about", "which", "their", "there", "these", "those", "other", "than",
188            "then", "also", "some",
189        ];
190        STOP.contains(&word)
191    }
192}
193
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every sample classifies as `expected` when given an empty context.
    fn assert_classified(expected: MemoryType, samples: &[&str]) {
        for sample in samples {
            assert_eq!(AutoClassifier::infer_memory_type(sample, ""), expected);
        }
    }

    #[test]
    fn test_classify_correction() {
        assert_classified(
            MemoryType::Fact,
            &[
                "Actually, the port is 8080 not 3000",
                "Correction: the API key expired",
            ],
        );
    }

    #[test]
    fn test_classify_preference() {
        assert_classified(
            MemoryType::Preference,
            &[
                "I prefer dark mode for the editor",
                "Never use tabs, always use spaces",
            ],
        );
    }

    #[test]
    fn test_classify_decision() {
        assert_classified(
            MemoryType::Decision,
            &[
                "We decided to use Tokio for async runtime",
                "Let's go with the microservice approach",
            ],
        );
    }

    #[test]
    fn test_classify_skill() {
        assert_classified(
            MemoryType::Skill,
            &[
                "Always run tests before commit",
                "Standard procedure: lint, test, then deploy",
            ],
        );
    }

    #[test]
    fn test_classify_profile() {
        assert_classified(
            MemoryType::UserProfile,
            &["My name is Won and I work at Oxios", "I'm a backend engineer"],
        );
    }

    #[test]
    fn test_classify_episode() {
        assert_classified(
            MemoryType::Episode,
            &[
                "Released v0.2.0 with memory consolidation",
                "Deployed the new API gateway yesterday",
            ],
        );
    }

    #[test]
    fn test_classify_default_fact() {
        assert_classified(
            MemoryType::Fact,
            &["API uses port 3000", "The database has 42 tables"],
        );
    }

    #[test]
    fn test_extract_tags() {
        let sentence = "Rust tokio async runtime memory consolidation system";
        let tags = AutoClassifier::extract_tags(sentence, 5);
        assert!(!tags.is_empty());

        let has_key_term = tags
            .iter()
            .any(|tag| tag.contains("rust") || tag.contains("memory"));
        assert!(has_key_term);
    }
}