1use std::collections::HashMap;
4
5use super::config::MIN_MEMORY_CONTENT_LENGTH;
6use super::extractor::infer_category_from_content;
7use super::retrieval::extract_context_keywords;
8use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
9
/// Kind of explicit feedback a user gave about stored memories.
///
/// Produced by `detect_feedback_patterns` and consumed by `apply_feedback_to_memory`.
/// The enum carries no data, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackAction {
    /// The user corrected something. Matching memories get the corrected text;
    /// if none match, the correction is stored as a new memory.
    Correct,
    /// The user no longer wants something. Matching memories are removed.
    Delete,
    /// The user explicitly asked for something to be remembered.
    Add,
    /// The user stated a dislike. It is stored as a `Preference` memory tagged `"negative"`.
    NegativePreference,
}
22
23#[derive(Debug, Clone)]
25pub struct FeedbackResult {
26 pub action: FeedbackAction,
27 pub category: Option<MemoryCategory>,
28 pub new_content: Option<String>,
29 pub search_keywords: Vec<String>,
30 pub original_text: String,
31}
32
33pub fn detect_feedback_patterns(text: &str) -> Vec<FeedbackResult> {
35 let mut results = Vec::new();
36 let text_lower = text.to_lowercase();
37
38 let correction_patterns = [
39 "不对,应该是",
40 "错了,实际上",
41 "不是,是",
42 "应该是",
43 "no, it should be",
44 "wrong, actually",
45 "should be",
46 ];
47
48 let delete_patterns = [
49 "不要那个",
50 "不需要那个",
51 "删掉那个",
52 "不再用",
53 "don't need that",
54 "no longer need",
55 "remove that",
56 ];
57
58 let add_patterns = [
59 "记一下",
60 "记住",
61 "记录一下",
62 "要记住",
63 "remember this",
64 "note this",
65 "keep this",
66 ];
67
68 let negative_patterns = [
69 "不喜欢",
70 "不偏好",
71 "讨厌",
72 "不想用",
73 "i don't like",
74 "i dislike",
75 "i hate",
76 ];
77
78 for pattern in correction_patterns {
79 if text_lower.contains(pattern) {
80 let content = extract_feedback_content(text, pattern);
81 if content.len() >= MIN_MEMORY_CONTENT_LENGTH {
82 results.push(FeedbackResult {
83 action: FeedbackAction::Correct,
84 category: Some(infer_category_from_content(&content)),
85 new_content: Some(content.clone()),
86 search_keywords: extract_context_keywords(&content),
87 original_text: text.to_string(),
88 });
89 }
90 }
91 }
92
93 for pattern in delete_patterns {
94 if text_lower.contains(pattern) {
95 let content = extract_feedback_content(text, pattern);
96 results.push(FeedbackResult {
97 action: FeedbackAction::Delete,
98 category: None,
99 new_content: None,
100 search_keywords: if content.is_empty() {
101 vec![pattern.to_string()]
102 } else {
103 extract_context_keywords(&content)
104 },
105 original_text: text.to_string(),
106 });
107 }
108 }
109
110 for pattern in add_patterns {
111 if text_lower.contains(pattern) {
112 let content = extract_feedback_content(text, pattern);
113 if content.len() >= MIN_MEMORY_CONTENT_LENGTH {
114 results.push(FeedbackResult {
115 action: FeedbackAction::Add,
116 category: Some(infer_category_from_content(&content)),
117 new_content: Some(content),
118 search_keywords: vec![],
119 original_text: text.to_string(),
120 });
121 }
122 }
123 }
124
125 for pattern in negative_patterns {
126 if text_lower.contains(pattern) {
127 let content = extract_feedback_content(text, pattern);
128 if content.len() >= MIN_MEMORY_CONTENT_LENGTH {
129 results.push(FeedbackResult {
130 action: FeedbackAction::NegativePreference,
131 category: Some(MemoryCategory::Preference),
132 new_content: Some(format!("不喜欢: {}", content)),
133 search_keywords: extract_context_keywords(&content),
134 original_text: text.to_string(),
135 });
136 }
137 }
138 }
139
140 results
141}
142
/// Returns the text after the first occurrence of `pattern` in `text`.
///
/// The text runs up to (not including) the first `.`, `。` or newline, and is
/// trimmed of surrounding whitespace.
///
/// The search is ASCII-case-insensitive. ASCII lowercasing is used instead of full
/// Unicode lowercasing because it never changes byte lengths. The match offset found
/// in the lowercased copy is therefore always a valid position in `text`. Full
/// lowercasing can grow characters (e.g. `İ` → `i̇`), which shifted the offset and
/// could slice inside a multi-byte character.
///
/// When no terminator follows, at most the first 100 bytes are kept. The cut is moved
/// back to a character boundary so it never splits a multi-byte (e.g. CJK) character.
///
/// Returns an empty string when `pattern` does not occur or nothing follows it.
fn extract_feedback_content(text: &str, pattern: &str) -> String {
    const MAX_UNTERMINATED_BYTES: usize = 100;

    let pos = match text
        .to_ascii_lowercase()
        .find(&pattern.to_ascii_lowercase())
    {
        Some(p) => p,
        None => return String::new(),
    };

    // ASCII lowercasing preserves byte lengths, so this is a char boundary in `text`.
    let start = pos + pattern.len();
    if start >= text.len() {
        return String::new();
    }

    let remaining = &text[start..];
    let end = match remaining.find(['.', '。', '\n']) {
        Some(idx) => idx,
        None => {
            let mut cut = remaining.len().min(MAX_UNTERMINATED_BYTES);
            while !remaining.is_char_boundary(cut) {
                cut -= 1;
            }
            cut
        }
    };

    remaining[..end].trim().to_string()
}
161
162pub fn apply_feedback_to_memory(memory: &mut AutoMemory, feedback: &FeedbackResult) -> usize {
164 let mut changes = 0;
165
166 match feedback.action {
167 FeedbackAction::Correct => {
168 if let Some(ref content) = feedback.new_content {
169 for entry in &mut memory.entries {
171 if feedback
172 .search_keywords
173 .iter()
174 .any(|k| entry.content.to_lowercase().contains(&k.to_lowercase()))
175 {
176 entry.content = content.clone();
177 entry.importance = entry.importance.max(80.0);
178 changes += 1;
179 }
180 }
181 if changes == 0 {
182 let category = feedback.category.unwrap_or(MemoryCategory::Finding);
184 memory.add_memory(category, content.clone(), None);
185 changes += 1;
186 }
187 }
188 }
189 FeedbackAction::Delete => {
190 let ids_to_delete: Vec<String> = memory
191 .entries
192 .iter()
193 .filter(|e| {
194 feedback
195 .search_keywords
196 .iter()
197 .any(|k| e.content.to_lowercase().contains(&k.to_lowercase()))
198 })
199 .take(3)
200 .map(|e| e.id.clone())
201 .collect();
202
203 for id in ids_to_delete {
204 if memory.remove(&id) {
205 changes += 1;
206 }
207 }
208 }
209 FeedbackAction::Add => {
210 if let Some(ref content) = feedback.new_content {
211 let category = feedback.category.unwrap_or(MemoryCategory::Finding);
212 let entry = MemoryEntry::manual(category, content.clone());
213 memory.add(entry);
214 changes += 1;
215 }
216 }
217 FeedbackAction::NegativePreference => {
218 if let Some(ref content) = feedback.new_content {
219 let mut entry = MemoryEntry::manual(MemoryCategory::Preference, content.clone());
220 entry.tags.push("negative".to_string());
221 memory.add(entry);
222 changes += 1;
223 }
224 }
225 }
226
227 changes
228}
229
/// Thresholds controlling `infer_preferences_from_behavior`.
#[derive(Debug, Clone)]
pub struct BehaviorInferenceConfig {
    /// Minimum number of user messages needed before any inference is attempted.
    /// It is also the minimum number of mentions a technology needs to be reported.
    pub min_occurrences: usize,
    /// Minimum `mentions / user_messages` ratio (capped at 1.0) for a technology
    /// to be reported.
    pub min_confidence: f64,
    /// Maximum number of inferences returned.
    pub max_inferences: usize,
}

impl Default for BehaviorInferenceConfig {
    /// Defaults: at least 2 messages and mentions, confidence ≥ 0.6, at most 5 results.
    fn default() -> Self {
        Self {
            min_occurrences: 2,
            min_confidence: 0.6,
            max_inferences: 5,
        }
    }
}
251
/// A preference inferred from repeated mentions in the user's messages.
#[derive(Clone, Debug)]
pub struct BehaviorInference {
    /// Human-readable statement of the inferred preference.
    pub content: String,
    /// Mentions divided by user-message count, capped at 1.0.
    pub confidence: f64,
    /// Number of mentions that were counted.
    pub occurrences: usize,
    /// Keywords identifying the subject (the technology's display name).
    pub keywords: Vec<String>,
}
260
261pub fn infer_preferences_from_behavior(
263 messages: &[crate::providers::Message],
264 config: &BehaviorInferenceConfig,
265) -> Vec<BehaviorInference> {
266 let mut inferences: Vec<BehaviorInference> = Vec::new();
267
268 let user_texts: Vec<String> = messages
269 .iter()
270 .filter_map(|msg| {
271 if msg.role == crate::providers::Role::User {
272 match &msg.content {
273 crate::providers::MessageContent::Text(t) => Some(t.clone()),
274 crate::providers::MessageContent::Blocks(blocks) => Some(
275 blocks
276 .iter()
277 .filter_map(|b| {
278 if let crate::providers::ContentBlock::Text { text } = b {
279 Some(text.as_str())
280 } else {
281 None
282 }
283 })
284 .collect::<Vec<_>>()
285 .join(" "),
286 ),
287 }
288 } else {
289 None
290 }
291 })
292 .collect();
293
294 if user_texts.len() < config.min_occurrences {
295 return inferences;
296 }
297
298 let all_text = user_texts.join(" ");
299 let all_text_lower = all_text.to_lowercase();
300
301 let tech_patterns: Vec<(&str, &str)> = vec![
302 ("rust", "Rust"),
303 ("python", "Python"),
304 ("react", "React"),
305 ("vue", "Vue"),
306 ("typescript", "TypeScript"),
307 ("go", "Go"),
308 ("docker", "Docker"),
309 ("postgres", "PostgreSQL"),
310 ("vim", "Vim"),
311 ];
312
313 let mut tech_counts: HashMap<&str, usize> = HashMap::new();
314 for (pattern, _) in &tech_patterns {
315 let count = all_text_lower.matches(pattern).count();
316 if count >= config.min_occurrences {
317 tech_counts.insert(pattern, count);
318 }
319 }
320
321 for (pattern, name) in tech_patterns {
322 if let Some(&count) = tech_counts.get(pattern) {
323 let confidence = (count as f64 / user_texts.len() as f64).min(1.0);
324 if confidence >= config.min_confidence {
325 inferences.push(BehaviorInference {
326 content: format!("用户频繁提及 {}", name),
327 confidence,
328 occurrences: count,
329 keywords: vec![name.to_string()],
330 });
331 }
332 }
333 }
334
335 inferences.truncate(config.max_inferences);
336 inferences
337}
338
339pub fn inference_to_memory_entry(inference: &BehaviorInference) -> MemoryEntry {
341 let mut entry = MemoryEntry::new(MemoryCategory::Preference, inference.content.clone(), None);
342 entry.importance = (inference.confidence * 70.0 + 30.0).min(80.0);
343 entry.tags = inference.keywords.clone();
344 entry
345}
346
347pub fn apply_behavior_inferences_to_memory(
350 messages: &[crate::providers::Message],
351 memory: &mut AutoMemory,
352 config: Option<&BehaviorInferenceConfig>,
353) -> usize {
354 let cfg = config.cloned().unwrap_or_default();
355 let inferences = infer_preferences_from_behavior(messages, &cfg);
356
357 let mut added = 0;
358 for inference in inferences {
359 let entry = inference_to_memory_entry(&inference);
360 if !memory.entries.iter().any(|e| e.content == entry.content) {
362 memory.entries.push(entry);
363 added += 1;
364 }
365 }
366
367 added
368}