Skip to main content

matrixcode_core/memory/
smart_retrieval.rs

1//! Smart Memory Retrieval: Advanced retrieval with focus relevance and time decay.
2//!
3//! This module implements intelligent memory retrieval using multiple factors:
4//! - TF-IDF similarity
5//! - Focus point relevance
6//! - Time decay
7//! - Importance weighting
8//! - Usage frequency
9
10use chrono::{DateTime, Utc};
11use std::collections::HashMap;
12
13use super::entry::{MemoryCategory, MemoryEntry};
14use super::retrieval::{TfIdfSearch, compute_relevance, expand_semantic_keywords};
15use crate::compress::{FocusPoint, FocusStatus};
16
17/// Smart memory retriever with advanced scoring.
18pub struct SmartMemoryRetriever {
19    /// Time decay configuration
20    time_decay_config: TimeDecayConfig,
21    /// Focus relevance weight
22    focus_weight: f32,
23    /// Time decay weight
24    time_decay_weight: f32,
25    /// Importance weight
26    importance_weight: f32,
27    /// TF-IDF weight
28    tfidf_weight: f32,
29    /// Usage frequency weight
30    usage_weight: f32,
31}
32
/// Time decay configuration.
///
/// A small plain-data value (three `f32`s), so it is `Copy` and can be
/// compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TimeDecayConfig {
    /// Half-life in hours: the age at which the decay factor reaches 50%.
    half_life_hours: f32,
    /// Lower bound on the decay factor; older entries never decay below this.
    min_decay: f32,
    /// Multiplier applied to the decay factor of entries younger than one hour.
    recent_boost: f32,
}
43
44impl Default for TimeDecayConfig {
45    fn default() -> Self {
46        Self {
47            half_life_hours: 24.0, // Decay to 50% after 24 hours
48            min_decay: 0.3,        // Keep at least 30% relevance
49            recent_boost: 1.5,     // Boost recent entries by 50%
50        }
51    }
52}
53
54impl Default for SmartMemoryRetriever {
55    fn default() -> Self {
56        Self {
57            time_decay_config: TimeDecayConfig::default(),
58            focus_weight: 0.25,
59            time_decay_weight: 0.15,
60            importance_weight: 0.20,
61            tfidf_weight: 0.30,
62            usage_weight: 0.10,
63        }
64    }
65}
66
67impl SmartMemoryRetriever {
68    /// Create a new smart retriever.
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    /// Create with custom weights.
74    pub fn with_weights(
75        focus_weight: f32,
76        time_decay_weight: f32,
77        importance_weight: f32,
78        tfidf_weight: f32,
79        usage_weight: f32,
80    ) -> Self {
81        // Normalize weights to sum to 1.0
82        let total = focus_weight + time_decay_weight + importance_weight + tfidf_weight + usage_weight;
83        Self {
84            time_decay_config: TimeDecayConfig::default(),
85            focus_weight: focus_weight / total,
86            time_decay_weight: time_decay_weight / total,
87            importance_weight: importance_weight / total,
88            tfidf_weight: tfidf_weight / total,
89            usage_weight: usage_weight / total,
90        }
91    }
92
93    /// Retrieve memories with smart scoring.
94    pub fn retrieve(
95        &self,
96        entries: &[MemoryEntry],
97        context_keywords: &[String],
98        active_foci: &[FocusPoint],
99        max_entries: usize,
100    ) -> Vec<MemoryEntry> {
101        if entries.is_empty() {
102            return Vec::new();
103        }
104
105        // Expand keywords semantically
106        let expanded_keywords = expand_semantic_keywords(context_keywords);
107
108        // Build TF-IDF index
109        let mut tfidf = TfIdfSearch::new();
110        // Note: We need to adapt TfIdfSearch to work with entries directly
111        // For now, we'll use a simplified approach
112
113        // Calculate scores for each entry
114        let mut scored_entries: Vec<(MemoryEntry, f32)> = entries
115            .iter()
116            .map(|entry| {
117                let score = self.calculate_entry_score(
118                    entry,
119                    &expanded_keywords,
120                    active_foci,
121                    Utc::now(),
122                );
123                (entry.clone(), score)
124            })
125            .collect();
126
127        // Sort by score (descending)
128        scored_entries.sort_by(|a, b| {
129            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
130        });
131
132        // Take top entries
133        scored_entries
134            .into_iter()
135            .take(max_entries)
136            .map(|(entry, _)| entry)
137            .collect()
138    }
139
140    /// Calculate comprehensive score for an entry.
141    fn calculate_entry_score(
142        &self,
143        entry: &MemoryEntry,
144        keywords: &[String],
145        active_foci: &[FocusPoint],
146        now: DateTime<Utc>,
147    ) -> f32 {
148        // 1. TF-IDF / keyword relevance
149        let relevance_score = self.calculate_relevance_score(entry, keywords);
150
151        // 2. Focus relevance
152        let focus_score = self.calculate_focus_relevance(entry, active_foci);
153
154        // 3. Time decay
155        let time_score = self.calculate_time_decay(entry, now);
156
157        // 4. Importance
158        let importance_score = entry.importance as f32 / 100.0;
159
160        // 5. Usage frequency
161        let usage_score = self.calculate_usage_score(entry);
162
163        // Combine scores with weights
164        let combined = 
165            relevance_score * self.tfidf_weight +
166            focus_score * self.focus_weight +
167            time_score * self.time_decay_weight +
168            importance_score * self.importance_weight +
169            usage_score * self.usage_weight;
170
171        // Apply manual boost
172        if entry.is_manual {
173            combined * 1.5
174        } else {
175            combined
176        }
177    }
178
179    /// Calculate keyword relevance score.
180    fn calculate_relevance_score(&self, entry: &MemoryEntry, keywords: &[String]) -> f32 {
181        let entry_lower = entry.content.to_lowercase();
182        let mut score = 0.0;
183
184        for keyword in keywords {
185            let kw_lower = keyword.to_lowercase();
186            if entry_lower.contains(&kw_lower) {
187                // Exact match: full score
188                score += 1.0;
189            } else {
190                // Check tags
191                if entry.tags.iter().any(|t| t.to_lowercase().contains(&kw_lower)) {
192                    score += 0.7;
193                }
194            }
195        }
196
197        // Normalize by number of keywords
198        if !keywords.is_empty() {
199            (score / keywords.len() as f32).min(1.0)
200        } else {
201            0.0
202        }
203    }
204
205    /// Calculate focus relevance score.
206    fn calculate_focus_relevance(&self, entry: &MemoryEntry, active_foci: &[FocusPoint]) -> f32 {
207        if active_foci.is_empty() {
208            return 0.5; // Neutral if no active focus
209        }
210
211        let entry_lower = entry.content.to_lowercase();
212        let mut max_score: f32 = 0.0;
213
214        for focus in active_foci {
215            let mut score: f32 = 0.0;
216
217            // Keyword match
218            for keyword in &focus.keywords {
219                if entry_lower.contains(&keyword.to_lowercase()) {
220                    score += 0.3;
221                }
222            }
223
224            // Entity match (higher weight)
225            for entity in &focus.entities {
226                if entry_lower.contains(&entity.to_lowercase()) {
227                    score += 0.5;
228                }
229            }
230
231            // Topic overlap
232            let topic_words = focus.topic.split_whitespace()
233                .map(|w| w.to_lowercase())
234                .collect::<Vec<_>>();
235            for word in &topic_words {
236                if entry_lower.contains(word) {
237                    score += 0.2;
238                }
239            }
240
241            // Weight by focus importance
242            score *= focus.importance;
243
244            // Weight by focus status
245            if focus.status == FocusStatus::Active {
246                score *= 1.2;
247            }
248
249            max_score = max_score.max(score);
250        }
251
252        max_score.min(1.0_f32)
253    }
254
255    /// Calculate time decay score.
256    fn calculate_time_decay(&self, entry: &MemoryEntry, now: DateTime<Utc>) -> f32 {
257        let hours_since_created = (now - entry.created_at).num_seconds() as f32 / 3600.0;
258
259        // Apply exponential decay
260        let decay_factor = 0.5_f32.powf(hours_since_created / self.time_decay_config.half_life_hours);
261
262        // Apply minimum threshold
263        let decayed = decay_factor.max(self.time_decay_config.min_decay);
264
265        // Apply recent boost
266        if hours_since_created < 1.0 {
267            decayed * self.time_decay_config.recent_boost
268        } else {
269            decayed
270        }
271    }
272
273    /// Calculate usage frequency score.
274    fn calculate_usage_score(&self, entry: &MemoryEntry) -> f32 {
275        // Normalize reference count
276        let ref_score = (entry.reference_count as f32 / 10.0).min(1.0);
277
278        // New entries have neutral score
279        if entry.reference_count == 0 {
280            0.5
281        } else {
282            ref_score
283        }
284    }
285
286    /// Generate smart summary with focus awareness.
287    pub fn generate_smart_summary(
288        &self,
289        entries: &[MemoryEntry],
290        context_keywords: &[String],
291        active_foci: &[FocusPoint],
292        max_entries: usize,
293    ) -> String {
294        let selected = self.retrieve(entries, context_keywords, active_foci, max_entries);
295
296        if selected.is_empty() {
297            return String::new();
298        }
299
300        let mut summary = String::from("【智能记忆检索】\n\n");
301
302        // Group by category
303        let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
304        for entry in &selected {
305            by_cat.entry(entry.category).or_default().push(entry);
306        }
307
308        // Add focus context if available
309        if !active_foci.is_empty() {
310            summary.push_str("当前聚焦:\n");
311            for focus in active_foci {
312                summary.push_str(&format!("  • {} (重要性: {:.0}%)\n", focus.topic, focus.importance * 100.0));
313            }
314            summary.push_str("\n");
315        }
316
317        // Add memories by category
318        for (cat, entries) in by_cat {
319            summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
320            for entry in entries {
321                summary.push_str(&format!("  {}\n", entry.format_for_prompt()));
322            }
323            summary.push_str("\n");
324        }
325
326        summary
327    }
328
329    /// Get retrieval statistics.
330    pub fn get_retrieval_stats(
331        &self,
332        entries: &[MemoryEntry],
333        context_keywords: &[String],
334        active_foci: &[FocusPoint],
335    ) -> RetrievalStats {
336        let expanded = expand_semantic_keywords(context_keywords);
337
338        let mut stats = RetrievalStats {
339            total_entries: entries.len(),
340            keyword_matches: 0,
341            focus_matches: 0,
342            recent_entries: 0,
343            highly_important: 0,
344            frequently_used: 0,
345            avg_score: 0.0,
346        };
347
348        let now = Utc::now();
349        let mut total_score = 0.0;
350
351        for entry in entries {
352            let score = self.calculate_entry_score(entry, &expanded, active_foci, now);
353            total_score += score;
354
355            // Count matches
356            if self.calculate_relevance_score(entry, &expanded) > 0.5 {
357                stats.keyword_matches += 1;
358            }
359            if self.calculate_focus_relevance(entry, active_foci) > 0.5 {
360                stats.focus_matches += 1;
361            }
362            if (now - entry.created_at).num_hours() < 1 {
363                stats.recent_entries += 1;
364            }
365            if entry.importance > 70.0 {
366                stats.highly_important += 1;
367            }
368            if entry.reference_count > 5 {
369                stats.frequently_used += 1;
370            }
371        }
372
373        if !entries.is_empty() {
374            stats.avg_score = total_score / entries.len() as f32;
375        }
376
377        stats
378    }
379}
380
/// Retrieval statistics.
///
/// `Default` yields all-zero statistics; `PartialEq` allows two snapshots to
/// be compared directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct RetrievalStats {
    /// Number of entries that were examined.
    pub total_entries: usize,
    /// Entries whose keyword relevance score exceeds 0.5.
    pub keyword_matches: usize,
    /// Entries whose focus relevance score exceeds 0.5.
    pub focus_matches: usize,
    /// Entries created less than one hour ago.
    pub recent_entries: usize,
    /// Entries whose importance exceeds 70.
    pub highly_important: usize,
    /// Entries referenced more than 5 times.
    pub frequently_used: usize,
    /// Mean combined score over all entries (0.0 when there are none).
    pub avg_score: f32,
}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons; exact `==` on `f32` sums is
    /// fragile because the operands are not exactly representable.
    const EPSILON: f32 = 1e-6;

    /// Sum of all five component weights of a retriever.
    fn weight_sum(retriever: &SmartMemoryRetriever) -> f32 {
        retriever.focus_weight
            + retriever.time_decay_weight
            + retriever.importance_weight
            + retriever.tfidf_weight
            + retriever.usage_weight
    }

    #[test]
    fn test_smart_retriever_creation() {
        let retriever = SmartMemoryRetriever::new();
        let total = weight_sum(&retriever);
        assert!(
            (total - 1.0).abs() < EPSILON,
            "default weights should sum to 1.0, got {}",
            total
        );
    }

    #[test]
    fn test_with_weights_normalizes() {
        let retriever = SmartMemoryRetriever::with_weights(1.0, 1.0, 1.0, 1.0, 1.0);
        assert!((retriever.focus_weight - 0.2).abs() < EPSILON);
        assert!((retriever.usage_weight - 0.2).abs() < EPSILON);
        assert!((weight_sum(&retriever) - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_time_decay_config() {
        let config = TimeDecayConfig::default();
        assert!((config.half_life_hours - 24.0).abs() < EPSILON);
        assert!((config.min_decay - 0.3).abs() < EPSILON);
    }

    #[test]
    fn test_empty_retrieval() {
        let retriever = SmartMemoryRetriever::new();
        let result = retriever.retrieve(&[], &[], &[], 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_empty_summary() {
        let retriever = SmartMemoryRetriever::new();
        let summary = retriever.generate_smart_summary(&[], &[], &[], 5);
        assert!(summary.is_empty());
    }
}
423}