Skip to main content

matrixcode_core/memory/
learning.rs

1//! Feedback learning and behavior inference.
2
3use std::collections::HashMap;
4
5use super::config::MIN_MEMORY_CONTENT_LENGTH;
6use super::extractor::infer_category_from_content;
7use super::retrieval::extract_context_keywords;
8use super::types::{AutoMemory, MemoryCategory, MemoryEntry};
9
10// ============================================================================
11// Feedback Detection
12// ============================================================================
13
/// Action to take when user feedback is detected.
///
/// The type is `Eq + Hash` so it can be used as a key in sets and maps
/// (for example, to group or de-duplicate feedback by action).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FeedbackAction {
    /// The user corrected earlier information.
    Correct,
    /// The user asked to drop something that was remembered.
    Delete,
    /// The user explicitly asked to remember something.
    Add,
    /// The user expressed a dislike.
    NegativePreference,
}
22
/// Result of feedback detection.
///
/// One value describes a single piece of feedback found in a user message,
/// together with what is needed to apply it to memory.
#[derive(Debug, Clone)]
pub struct FeedbackResult {
    /// What kind of feedback was detected.
    pub action: FeedbackAction,
    /// Category for the affected memory; `None` for `Delete` feedback.
    pub category: Option<MemoryCategory>,
    /// Text to store; `None` for `Delete` feedback.
    pub new_content: Option<String>,
    /// Keywords used to locate existing entries; left empty for `Add` feedback.
    pub search_keywords: Vec<String>,
    /// The complete user text the feedback was detected in.
    pub original_text: String,
}
32
33/// Detect user feedback patterns.
34pub fn detect_feedback_patterns(text: &str) -> Vec<FeedbackResult> {
35    let mut results = Vec::new();
36    let text_lower = text.to_lowercase();
37
38    let correction_patterns = [
39        "不对,应该是",
40        "错了,实际上",
41        "不是,是",
42        "应该是",
43        "no, it should be",
44        "wrong, actually",
45        "should be",
46    ];
47
48    let delete_patterns = [
49        "不要那个",
50        "不需要那个",
51        "删掉那个",
52        "不再用",
53        "don't need that",
54        "no longer need",
55        "remove that",
56    ];
57
58    let add_patterns = [
59        "记一下",
60        "记住",
61        "记录一下",
62        "要记住",
63        "remember this",
64        "note this",
65        "keep this",
66    ];
67
68    let negative_patterns = [
69        "不喜欢",
70        "不偏好",
71        "讨厌",
72        "不想用",
73        "i don't like",
74        "i dislike",
75        "i hate",
76    ];
77
78    for pattern in correction_patterns {
79        if text_lower.contains(pattern) {
80            let content = extract_feedback_content(text, pattern);
81            if content.len() >= MIN_MEMORY_CONTENT_LENGTH {
82                results.push(FeedbackResult {
83                    action: FeedbackAction::Correct,
84                    category: Some(infer_category_from_content(&content)),
85                    new_content: Some(content.clone()),
86                    search_keywords: extract_context_keywords(&content),
87                    original_text: text.to_string(),
88                });
89            }
90        }
91    }
92
93    for pattern in delete_patterns {
94        if text_lower.contains(pattern) {
95            let content = extract_feedback_content(text, pattern);
96            results.push(FeedbackResult {
97                action: FeedbackAction::Delete,
98                category: None,
99                new_content: None,
100                search_keywords: if content.is_empty() {
101                    vec![pattern.to_string()]
102                } else {
103                    extract_context_keywords(&content)
104                },
105                original_text: text.to_string(),
106            });
107        }
108    }
109
110    for pattern in add_patterns {
111        if text_lower.contains(pattern) {
112            let content = extract_feedback_content(text, pattern);
113            if content.len() >= MIN_MEMORY_CONTENT_LENGTH {
114                results.push(FeedbackResult {
115                    action: FeedbackAction::Add,
116                    category: Some(infer_category_from_content(&content)),
117                    new_content: Some(content),
118                    search_keywords: vec![],
119                    original_text: text.to_string(),
120                });
121            }
122        }
123    }
124
125    for pattern in negative_patterns {
126        if text_lower.contains(pattern) {
127            let content = extract_feedback_content(text, pattern);
128            if content.len() >= MIN_MEMORY_CONTENT_LENGTH {
129                results.push(FeedbackResult {
130                    action: FeedbackAction::NegativePreference,
131                    category: Some(MemoryCategory::Preference),
132                    new_content: Some(format!("不喜欢: {}", content)),
133                    search_keywords: extract_context_keywords(&content),
134                    original_text: text.to_string(),
135                });
136            }
137        }
138    }
139
140    results
141}
142
/// Extracts the text that follows the first case-insensitive occurrence of
/// `pattern` in `text`.
///
/// The extracted text runs up to (not including) the first `.`, `。` or
/// newline. When no such terminator exists it is capped at 100 bytes, rounded
/// down to a character boundary so multi-byte text is never split. The result
/// is trimmed. An empty string is returned when `pattern` does not occur or
/// nothing follows it.
///
/// The match position is computed on the original text rather than on a
/// lowercased copy: lowercasing can change byte lengths, so offsets taken from
/// a lowercased copy may point at the wrong place (or inside a character) in
/// the original.
fn extract_feedback_content(text: &str, pattern: &str) -> String {
    const MAX_UNTERMINATED_BYTES: usize = 100;

    let start = match find_pattern_end_ignore_case(text, pattern) {
        Some(end_of_pattern) => end_of_pattern,
        None => return String::new(),
    };
    if start >= text.len() {
        return String::new();
    }

    let remaining = &text[start..];
    let end = match remaining.find(['.', '。', '\n']) {
        Some(terminator) => terminator,
        None => {
            let mut cut = remaining.len().min(MAX_UNTERMINATED_BYTES);
            // Offset 0 is always a boundary, so this loop terminates.
            while !remaining.is_char_boundary(cut) {
                cut -= 1;
            }
            cut
        }
    };

    remaining[..end].trim().to_string()
}

/// Returns the byte offset in `text` just past the first case-insensitive
/// occurrence of `pattern`, or `None` when there is no occurrence.
fn find_pattern_end_ignore_case(text: &str, pattern: &str) -> Option<usize> {
    let needle = pattern.to_lowercase();
    text.char_indices().find_map(|(start, _)| {
        lowercase_prefix_len(&text[start..], &needle).map(|matched| start + matched)
    })
}

/// If `haystack` starts with `needle_lower` once lowercased character by
/// character, returns how many bytes of `haystack` the match covers.
fn lowercase_prefix_len(haystack: &str, needle_lower: &str) -> Option<usize> {
    let mut wanted = needle_lower.chars();
    let mut next_wanted = wanted.next();

    for (offset, ch) in haystack.char_indices() {
        if next_wanted.is_none() {
            // The whole needle was consumed by the preceding characters.
            return Some(offset);
        }
        for lowered in ch.to_lowercase() {
            if next_wanted != Some(lowered) {
                return None;
            }
            next_wanted = wanted.next();
        }
    }

    if next_wanted.is_none() {
        Some(haystack.len())
    } else {
        None
    }
}
161
162/// Apply feedback to memory.
163pub fn apply_feedback_to_memory(memory: &mut AutoMemory, feedback: &FeedbackResult) -> usize {
164    let mut changes = 0;
165
166    match feedback.action {
167        FeedbackAction::Correct => {
168            if let Some(ref content) = feedback.new_content {
169                // Find matching entries and update
170                for entry in &mut memory.entries {
171                    if feedback
172                        .search_keywords
173                        .iter()
174                        .any(|k| entry.content.to_lowercase().contains(&k.to_lowercase()))
175                    {
176                        entry.content = content.clone();
177                        entry.importance = entry.importance.max(80.0);
178                        changes += 1;
179                    }
180                }
181                if changes == 0 {
182                    // No matching entry, add new
183                    let category = feedback.category.unwrap_or(MemoryCategory::Finding);
184                    memory.add_memory(category, content.clone(), None);
185                    changes += 1;
186                }
187            }
188        }
189        FeedbackAction::Delete => {
190            let ids_to_delete: Vec<String> = memory
191                .entries
192                .iter()
193                .filter(|e| {
194                    feedback
195                        .search_keywords
196                        .iter()
197                        .any(|k| e.content.to_lowercase().contains(&k.to_lowercase()))
198                })
199                .take(3)
200                .map(|e| e.id.clone())
201                .collect();
202
203            for id in ids_to_delete {
204                if memory.remove(&id) {
205                    changes += 1;
206                }
207            }
208        }
209        FeedbackAction::Add => {
210            if let Some(ref content) = feedback.new_content {
211                let category = feedback.category.unwrap_or(MemoryCategory::Finding);
212                let entry = MemoryEntry::manual(category, content.clone());
213                memory.add(entry);
214                changes += 1;
215            }
216        }
217        FeedbackAction::NegativePreference => {
218            if let Some(ref content) = feedback.new_content {
219                let mut entry = MemoryEntry::manual(MemoryCategory::Preference, content.clone());
220                entry.tags.push("negative".to_string());
221                memory.add(entry);
222                changes += 1;
223            }
224        }
225    }
226
227    changes
228}
229
230// ============================================================================
231// Behavior Inference
232// ============================================================================
233
/// Configuration for behavior inference.
///
/// Derives `Debug` like every other public type in this module, so a
/// configuration can be logged and used in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct BehaviorInferenceConfig {
    /// Minimum number of user messages, and of mentions of a term, required
    /// before an inference is considered.
    pub min_occurrences: usize,
    /// Minimum confidence an inference must reach to be reported.
    pub min_confidence: f64,
    /// Upper bound on the number of inferences returned.
    pub max_inferences: usize,
}
241
242impl Default for BehaviorInferenceConfig {
243    fn default() -> Self {
244        Self {
245            min_occurrences: 2,
246            min_confidence: 0.6,
247            max_inferences: 5,
248        }
249    }
250}
251
/// Result of behavior inference.
///
/// Describes one preference inferred from how often the user mentioned a
/// technology.
#[derive(Debug, Clone)]
pub struct BehaviorInference {
    /// Human-readable statement of the inferred preference.
    pub content: String,
    /// Mentions divided by the number of user messages, capped at 1.0.
    pub confidence: f64,
    /// How many times the technology was mentioned.
    pub occurrences: usize,
    /// Display name(s) of the technology; copied into the entry's tags when
    /// the inference is converted to a memory entry.
    pub keywords: Vec<String>,
}
260
261/// Infer preferences from conversation patterns.
262pub fn infer_preferences_from_behavior(
263    messages: &[crate::providers::Message],
264    config: &BehaviorInferenceConfig,
265) -> Vec<BehaviorInference> {
266    let mut inferences: Vec<BehaviorInference> = Vec::new();
267
268    let user_texts: Vec<String> = messages
269        .iter()
270        .filter_map(|msg| {
271            if msg.role == crate::providers::Role::User {
272                match &msg.content {
273                    crate::providers::MessageContent::Text(t) => Some(t.clone()),
274                    crate::providers::MessageContent::Blocks(blocks) => Some(
275                        blocks
276                            .iter()
277                            .filter_map(|b| {
278                                if let crate::providers::ContentBlock::Text { text } = b {
279                                    Some(text.as_str())
280                                } else {
281                                    None
282                                }
283                            })
284                            .collect::<Vec<_>>()
285                            .join(" "),
286                    ),
287                }
288            } else {
289                None
290            }
291        })
292        .collect();
293
294    if user_texts.len() < config.min_occurrences {
295        return inferences;
296    }
297
298    let all_text = user_texts.join(" ");
299    let all_text_lower = all_text.to_lowercase();
300
301    let tech_patterns: Vec<(&str, &str)> = vec![
302        ("rust", "Rust"),
303        ("python", "Python"),
304        ("react", "React"),
305        ("vue", "Vue"),
306        ("typescript", "TypeScript"),
307        ("go", "Go"),
308        ("docker", "Docker"),
309        ("postgres", "PostgreSQL"),
310        ("vim", "Vim"),
311    ];
312
313    let mut tech_counts: HashMap<&str, usize> = HashMap::new();
314    for (pattern, _) in &tech_patterns {
315        let count = all_text_lower.matches(pattern).count();
316        if count >= config.min_occurrences {
317            tech_counts.insert(pattern, count);
318        }
319    }
320
321    for (pattern, name) in tech_patterns {
322        if let Some(&count) = tech_counts.get(pattern) {
323            let confidence = (count as f64 / user_texts.len() as f64).min(1.0);
324            if confidence >= config.min_confidence {
325                inferences.push(BehaviorInference {
326                    content: format!("用户频繁提及 {}", name),
327                    confidence,
328                    occurrences: count,
329                    keywords: vec![name.to_string()],
330                });
331            }
332        }
333    }
334
335    inferences.truncate(config.max_inferences);
336    inferences
337}
338
339/// Convert inference to memory entry.
340pub fn inference_to_memory_entry(inference: &BehaviorInference) -> MemoryEntry {
341    let mut entry = MemoryEntry::new(MemoryCategory::Preference, inference.content.clone(), None);
342    entry.importance = (inference.confidence * 70.0 + 30.0).min(80.0);
343    entry.tags = inference.keywords.clone();
344    entry
345}
346
347/// Apply behavior inferences to memory.
348/// Returns the number of new entries added.
349pub fn apply_behavior_inferences_to_memory(
350    messages: &[crate::providers::Message],
351    memory: &mut AutoMemory,
352    config: Option<&BehaviorInferenceConfig>,
353) -> usize {
354    let cfg = config.cloned().unwrap_or_default();
355    let inferences = infer_preferences_from_behavior(messages, &cfg);
356
357    let mut added = 0;
358    for inference in inferences {
359        let entry = inference_to_memory_entry(&inference);
360        // Check if similar entry already exists
361        if !memory.entries.iter().any(|e| e.content == entry.content) {
362            memory.entries.push(entry);
363            added += 1;
364        }
365    }
366
367    added
368}