1use crate::memory::types::MemoryType;
8
/// Phrases suggesting the user is correcting something said earlier.
///
/// `AutoClassifier::infer_memory_type` tests these before every other category and
/// classifies a hit as `MemoryType::Fact`. Matching is a plain substring search on the
/// lowercased text.
const CORRECTION_PATTERNS: &[&str] = &[
    "actually",
    "no, it's",
    "that's wrong",
    "correction",
    "i meant",
    "not that",
    "i was wrong",
];
23
/// Phrases expressing a personal preference or a standing instruction
/// ("always use", "never use", "please use", ...).
///
/// Hits are classified as `MemoryType::Preference` by
/// `AutoClassifier::infer_memory_type`. Preferences are checked before skills, so
/// `"make sure to use"` here takes precedence over the broader `"make sure to"` entry
/// in `SKILL_PATTERNS`.
const PREFERENCE_PATTERNS: &[&str] = &[
    "i prefer",
    "always use",
    "i like",
    "i don't",
    "never use",
    "i'd rather",
    "my preference",
    "please use",
    "make sure to use",
];
36
/// Phrases recording a decision or a chosen approach ("decided to", "let's go with", ...).
///
/// Hits are classified as `MemoryType::Decision` by `AutoClassifier::infer_memory_type`,
/// unless a correction or preference pattern matched first.
const DECISION_PATTERNS: &[&str] = &[
    "decided to",
    "we chose",
    "let's go with",
    "we'll use",
    "i decided",
    "the decision is",
    "going with",
];
47
/// Phrases describing a repeatable procedure or workflow ("always run", "step by step", ...).
///
/// Hits are classified as `MemoryType::Skill` by `AutoClassifier::infer_memory_type`,
/// unless a correction, preference or decision pattern matched first.
const SKILL_PATTERNS: &[&str] = &[
    "always run",
    "before commit",
    "every time",
    "make sure to",
    "workflow is",
    "standard procedure",
    "first, then",
    "step by step",
];
59
/// Phrases in which the user describes themselves: name, employer, role, expertise.
///
/// Hits are classified as `MemoryType::UserProfile` by
/// `AutoClassifier::infer_memory_type`, unless an earlier category matched first.
///
/// Because of the trailing space, the `"i'm a "` / `"i am a "` entries do not match
/// text such as "i'm about". That same space means they never match "i'm an ...".
/// The explicit `"an "` variants cover roles starting with a vowel ("i'm an engineer").
const PROFILE_PATTERNS: &[&str] = &[
    "my name is",
    "i work at",
    "i'm a ",
    "i'm an ",
    "i am a ",
    "i am an ",
    "my role is",
    "my job is",
    "i specialize",
    "my background",
];
71
/// Past-tense verbs describing a one-off event ("deployed", "released", ...).
///
/// `AutoClassifier::infer_memory_type` checks these last, so any other category wins
/// when several match. A hit yields `MemoryType::Episode`. Matching is substring-based,
/// so these also fire inside longer words such as "restarted" or "unreleased".
const EPISODE_PATTERNS: &[&str] = &[
    "deployed",
    "released",
    "launched",
    "completed",
    "finished",
    "started",
];
81
/// Stateless, keyword-based classifier for memory content.
///
/// The type carries no data; all functionality lives in associated functions
/// (`AutoClassifier::infer_memory_type`, `AutoClassifier::extract_tags`). The derives
/// let callers embed it in other types that need `Debug`, `Clone` or `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AutoClassifier;
91
92impl AutoClassifier {
93 pub fn infer_memory_type(content: &str, _context: &str) -> MemoryType {
98 let content_lower = content.to_lowercase();
99
100 if Self::is_correction(&content_lower) {
110 return MemoryType::Fact;
111 }
112
113 if Self::is_preference(&content_lower) {
114 return MemoryType::Preference;
115 }
116
117 if Self::is_decision(&content_lower) {
118 return MemoryType::Decision;
119 }
120
121 if Self::is_skill(&content_lower) {
122 return MemoryType::Skill;
123 }
124
125 if Self::is_profile(&content_lower) {
126 return MemoryType::UserProfile;
127 }
128
129 if Self::is_episode(&content_lower) {
130 return MemoryType::Episode;
131 }
132
133 MemoryType::Fact
134 }
135
136 fn is_correction(content_lower: &str) -> bool {
137 CORRECTION_PATTERNS
138 .iter()
139 .any(|p| content_lower.contains(p))
140 }
141
142 fn is_preference(content_lower: &str) -> bool {
143 PREFERENCE_PATTERNS
144 .iter()
145 .any(|p| content_lower.contains(p))
146 }
147
148 fn is_decision(content_lower: &str) -> bool {
149 DECISION_PATTERNS.iter().any(|p| content_lower.contains(p))
150 }
151
152 fn is_skill(content_lower: &str) -> bool {
153 SKILL_PATTERNS.iter().any(|p| content_lower.contains(p))
154 }
155
156 fn is_profile(content_lower: &str) -> bool {
157 PROFILE_PATTERNS.iter().any(|p| content_lower.contains(p))
158 }
159
160 fn is_episode(content_lower: &str) -> bool {
161 EPISODE_PATTERNS.iter().any(|p| content_lower.contains(p))
162 }
163
164 pub fn extract_tags(content: &str, max_tags: usize) -> Vec<String> {
169 let mut tags: Vec<String> = content
170 .split_whitespace()
171 .map(|w| {
172 w.trim_matches(|c: char| c.is_ascii_punctuation())
173 .to_lowercase()
174 })
175 .filter(|w| w.len() > 3 && !Self::is_stop_word(w))
176 .collect();
177
178 tags.sort();
179 tags.dedup();
180 tags.truncate(max_tags);
181 tags
182 }
183
184 fn is_stop_word(word: &str) -> bool {
185 const STOP: &[&str] = &[
186 "that", "this", "with", "from", "have", "been", "were", "will", "would", "could",
187 "should", "about", "which", "their", "there", "these", "those", "other", "than",
188 "then", "also", "some",
189 ];
190 STOP.contains(&word)
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every entry of `inputs` is classified as `expected`, naming the
    /// offending input on failure.
    fn assert_classified(inputs: &[&str], expected: MemoryType) {
        for input in inputs {
            assert_eq!(
                AutoClassifier::infer_memory_type(input, ""),
                expected,
                "unexpected classification for {:?}",
                input
            );
        }
    }

    #[test]
    fn test_classify_correction() {
        assert_classified(
            &[
                "Actually, the port is 8080 not 3000",
                "Correction: the API key expired",
            ],
            MemoryType::Fact,
        );
    }

    #[test]
    fn test_classify_preference() {
        assert_classified(
            &[
                "I prefer dark mode for the editor",
                "Never use tabs, always use spaces",
            ],
            MemoryType::Preference,
        );
    }

    #[test]
    fn test_classify_decision() {
        assert_classified(
            &[
                "We decided to use Tokio for async runtime",
                "Let's go with the microservice approach",
            ],
            MemoryType::Decision,
        );
    }

    #[test]
    fn test_classify_skill() {
        assert_classified(
            &[
                "Always run tests before commit",
                "Standard procedure: lint, test, then deploy",
            ],
            MemoryType::Skill,
        );
    }

    #[test]
    fn test_classify_profile() {
        assert_classified(
            &["My name is Won and I work at Oxios", "I'm a backend engineer"],
            MemoryType::UserProfile,
        );
    }

    #[test]
    fn test_classify_episode() {
        assert_classified(
            &[
                "Released v0.2.0 with memory consolidation",
                "Deployed the new API gateway yesterday",
            ],
            MemoryType::Episode,
        );
    }

    #[test]
    fn test_classify_default_fact() {
        assert_classified(
            &["API uses port 3000", "The database has 42 tables"],
            MemoryType::Fact,
        );
    }

    #[test]
    fn test_extract_tags() {
        let tags =
            AutoClassifier::extract_tags("Rust tokio async runtime memory consolidation system", 5);
        // Tags are sorted alphabetically before truncation, so the first five survive.
        assert_eq!(tags, vec!["async", "consolidation", "memory", "runtime", "rust"]);
    }

    #[test]
    fn test_extract_tags_skips_stop_words_and_short_words() {
        let tags = AutoClassifier::extract_tags("This is the memory with Rust", 10);
        assert_eq!(tags, vec!["memory", "rust"]);
    }

    #[test]
    fn test_extract_tags_dedups_and_truncates() {
        let tags = AutoClassifier::extract_tags("memory, Memory! (memory) tokio rust", 2);
        assert_eq!(tags, vec!["memory", "rust"]);
        assert!(AutoClassifier::extract_tags("memory tokio", 0).is_empty());
    }
}