1use chrono::{DateTime, Utc};
11use std::collections::HashMap;
12
13use super::entry::{MemoryCategory, MemoryEntry};
14use super::retrieval::{TfIdfSearch, compute_relevance, expand_semantic_keywords};
15use crate::compress::{FocusPoint, FocusStatus};
16
17pub struct SmartMemoryRetriever {
19 time_decay_config: TimeDecayConfig,
21 focus_weight: f32,
23 time_decay_weight: f32,
25 importance_weight: f32,
27 tfidf_weight: f32,
29 usage_weight: f32,
31}
32
/// Shape of the age-based time score used by the retriever.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TimeDecayConfig {
    /// Age in hours after which the decay factor has halved.
    half_life_hours: f32,
    /// Floor for the decay factor, so old entries never drop below this value.
    min_decay: f32,
    /// Multiplier applied to entries younger than one hour.
    recent_boost: f32,
}

impl TimeDecayConfig {
    /// Builds a custom decay configuration.
    ///
    /// `half_life_hours` is expected to be positive; the fields are private, so this is
    /// the only way to obtain a non-default configuration from outside the module.
    #[must_use]
    pub fn new(half_life_hours: f32, min_decay: f32, recent_boost: f32) -> Self {
        Self {
            half_life_hours,
            min_decay,
            recent_boost,
        }
    }
}
43
44impl Default for TimeDecayConfig {
45 fn default() -> Self {
46 Self {
47 half_life_hours: 24.0, min_decay: 0.3, recent_boost: 1.5, }
51 }
52}
53
54impl Default for SmartMemoryRetriever {
55 fn default() -> Self {
56 Self {
57 time_decay_config: TimeDecayConfig::default(),
58 focus_weight: 0.25,
59 time_decay_weight: 0.15,
60 importance_weight: 0.20,
61 tfidf_weight: 0.30,
62 usage_weight: 0.10,
63 }
64 }
65}
66
67impl SmartMemoryRetriever {
68 pub fn new() -> Self {
70 Self::default()
71 }
72
73 pub fn with_weights(
75 focus_weight: f32,
76 time_decay_weight: f32,
77 importance_weight: f32,
78 tfidf_weight: f32,
79 usage_weight: f32,
80 ) -> Self {
81 let total = focus_weight + time_decay_weight + importance_weight + tfidf_weight + usage_weight;
83 Self {
84 time_decay_config: TimeDecayConfig::default(),
85 focus_weight: focus_weight / total,
86 time_decay_weight: time_decay_weight / total,
87 importance_weight: importance_weight / total,
88 tfidf_weight: tfidf_weight / total,
89 usage_weight: usage_weight / total,
90 }
91 }
92
93 pub fn retrieve(
95 &self,
96 entries: &[MemoryEntry],
97 context_keywords: &[String],
98 active_foci: &[FocusPoint],
99 max_entries: usize,
100 ) -> Vec<MemoryEntry> {
101 if entries.is_empty() {
102 return Vec::new();
103 }
104
105 let expanded_keywords = expand_semantic_keywords(context_keywords);
107
108 let mut tfidf = TfIdfSearch::new();
110 let mut scored_entries: Vec<(MemoryEntry, f32)> = entries
115 .iter()
116 .map(|entry| {
117 let score = self.calculate_entry_score(
118 entry,
119 &expanded_keywords,
120 active_foci,
121 Utc::now(),
122 );
123 (entry.clone(), score)
124 })
125 .collect();
126
127 scored_entries.sort_by(|a, b| {
129 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
130 });
131
132 scored_entries
134 .into_iter()
135 .take(max_entries)
136 .map(|(entry, _)| entry)
137 .collect()
138 }
139
140 fn calculate_entry_score(
142 &self,
143 entry: &MemoryEntry,
144 keywords: &[String],
145 active_foci: &[FocusPoint],
146 now: DateTime<Utc>,
147 ) -> f32 {
148 let relevance_score = self.calculate_relevance_score(entry, keywords);
150
151 let focus_score = self.calculate_focus_relevance(entry, active_foci);
153
154 let time_score = self.calculate_time_decay(entry, now);
156
157 let importance_score = entry.importance as f32 / 100.0;
159
160 let usage_score = self.calculate_usage_score(entry);
162
163 let combined =
165 relevance_score * self.tfidf_weight +
166 focus_score * self.focus_weight +
167 time_score * self.time_decay_weight +
168 importance_score * self.importance_weight +
169 usage_score * self.usage_weight;
170
171 if entry.is_manual {
173 combined * 1.5
174 } else {
175 combined
176 }
177 }
178
179 fn calculate_relevance_score(&self, entry: &MemoryEntry, keywords: &[String]) -> f32 {
181 let entry_lower = entry.content.to_lowercase();
182 let mut score = 0.0;
183
184 for keyword in keywords {
185 let kw_lower = keyword.to_lowercase();
186 if entry_lower.contains(&kw_lower) {
187 score += 1.0;
189 } else {
190 if entry.tags.iter().any(|t| t.to_lowercase().contains(&kw_lower)) {
192 score += 0.7;
193 }
194 }
195 }
196
197 if !keywords.is_empty() {
199 (score / keywords.len() as f32).min(1.0)
200 } else {
201 0.0
202 }
203 }
204
205 fn calculate_focus_relevance(&self, entry: &MemoryEntry, active_foci: &[FocusPoint]) -> f32 {
207 if active_foci.is_empty() {
208 return 0.5; }
210
211 let entry_lower = entry.content.to_lowercase();
212 let mut max_score: f32 = 0.0;
213
214 for focus in active_foci {
215 let mut score: f32 = 0.0;
216
217 for keyword in &focus.keywords {
219 if entry_lower.contains(&keyword.to_lowercase()) {
220 score += 0.3;
221 }
222 }
223
224 for entity in &focus.entities {
226 if entry_lower.contains(&entity.to_lowercase()) {
227 score += 0.5;
228 }
229 }
230
231 let topic_words = focus.topic.split_whitespace()
233 .map(|w| w.to_lowercase())
234 .collect::<Vec<_>>();
235 for word in &topic_words {
236 if entry_lower.contains(word) {
237 score += 0.2;
238 }
239 }
240
241 score *= focus.importance;
243
244 if focus.status == FocusStatus::Active {
246 score *= 1.2;
247 }
248
249 max_score = max_score.max(score);
250 }
251
252 max_score.min(1.0_f32)
253 }
254
255 fn calculate_time_decay(&self, entry: &MemoryEntry, now: DateTime<Utc>) -> f32 {
257 let hours_since_created = (now - entry.created_at).num_seconds() as f32 / 3600.0;
258
259 let decay_factor = 0.5_f32.powf(hours_since_created / self.time_decay_config.half_life_hours);
261
262 let decayed = decay_factor.max(self.time_decay_config.min_decay);
264
265 if hours_since_created < 1.0 {
267 decayed * self.time_decay_config.recent_boost
268 } else {
269 decayed
270 }
271 }
272
273 fn calculate_usage_score(&self, entry: &MemoryEntry) -> f32 {
275 let ref_score = (entry.reference_count as f32 / 10.0).min(1.0);
277
278 if entry.reference_count == 0 {
280 0.5
281 } else {
282 ref_score
283 }
284 }
285
286 pub fn generate_smart_summary(
288 &self,
289 entries: &[MemoryEntry],
290 context_keywords: &[String],
291 active_foci: &[FocusPoint],
292 max_entries: usize,
293 ) -> String {
294 let selected = self.retrieve(entries, context_keywords, active_foci, max_entries);
295
296 if selected.is_empty() {
297 return String::new();
298 }
299
300 let mut summary = String::from("【智能记忆检索】\n\n");
301
302 let mut by_cat: HashMap<MemoryCategory, Vec<&MemoryEntry>> = HashMap::new();
304 for entry in &selected {
305 by_cat.entry(entry.category).or_default().push(entry);
306 }
307
308 if !active_foci.is_empty() {
310 summary.push_str("当前聚焦:\n");
311 for focus in active_foci {
312 summary.push_str(&format!(" • {} (重要性: {:.0}%)\n", focus.topic, focus.importance * 100.0));
313 }
314 summary.push_str("\n");
315 }
316
317 for (cat, entries) in by_cat {
319 summary.push_str(&format!("{} {}:\n", cat.icon(), cat.display_name()));
320 for entry in entries {
321 summary.push_str(&format!(" {}\n", entry.format_for_prompt()));
322 }
323 summary.push_str("\n");
324 }
325
326 summary
327 }
328
329 pub fn get_retrieval_stats(
331 &self,
332 entries: &[MemoryEntry],
333 context_keywords: &[String],
334 active_foci: &[FocusPoint],
335 ) -> RetrievalStats {
336 let expanded = expand_semantic_keywords(context_keywords);
337
338 let mut stats = RetrievalStats {
339 total_entries: entries.len(),
340 keyword_matches: 0,
341 focus_matches: 0,
342 recent_entries: 0,
343 highly_important: 0,
344 frequently_used: 0,
345 avg_score: 0.0,
346 };
347
348 let now = Utc::now();
349 let mut total_score = 0.0;
350
351 for entry in entries {
352 let score = self.calculate_entry_score(entry, &expanded, active_foci, now);
353 total_score += score;
354
355 if self.calculate_relevance_score(entry, &expanded) > 0.5 {
357 stats.keyword_matches += 1;
358 }
359 if self.calculate_focus_relevance(entry, active_foci) > 0.5 {
360 stats.focus_matches += 1;
361 }
362 if (now - entry.created_at).num_hours() < 1 {
363 stats.recent_entries += 1;
364 }
365 if entry.importance > 70.0 {
366 stats.highly_important += 1;
367 }
368 if entry.reference_count > 5 {
369 stats.frequently_used += 1;
370 }
371 }
372
373 if !entries.is_empty() {
374 stats.avg_score = total_score / entries.len() as f32;
375 }
376
377 stats
378 }
379}
380
/// Summary counters produced by `SmartMemoryRetriever::get_retrieval_stats`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct RetrievalStats {
    /// Number of entries examined.
    pub total_entries: usize,
    /// Entries whose keyword-relevance score exceeds 0.5.
    pub keyword_matches: usize,
    /// Entries whose focus-relevance score exceeds 0.5.
    pub focus_matches: usize,
    /// Entries created less than one hour before the stats were computed.
    pub recent_entries: usize,
    /// Entries with importance above 70.
    pub highly_important: usize,
    /// Entries referenced more than five times.
    pub frequently_used: usize,
    /// Mean combined score over all entries; 0.0 when there are none.
    pub avg_score: f32,
}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing sums of `f32` weights, which are not exact.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_smart_retriever_creation() {
        let retriever = SmartMemoryRetriever::new();
        let sum = retriever.focus_weight
            + retriever.time_decay_weight
            + retriever.importance_weight
            + retriever.tfidf_weight
            + retriever.usage_weight;
        assert!((sum - 1.0).abs() < EPS, "default weights sum to {}", sum);
    }

    #[test]
    fn test_with_weights_normalises() {
        let retriever = SmartMemoryRetriever::with_weights(1.0, 1.0, 1.0, 1.0, 1.0);
        for w in [
            retriever.focus_weight,
            retriever.time_decay_weight,
            retriever.importance_weight,
            retriever.tfidf_weight,
            retriever.usage_weight,
        ] {
            assert!((w - 0.2).abs() < EPS, "weight {} not normalised", w);
        }
    }

    #[test]
    fn test_time_decay_config() {
        let config = TimeDecayConfig::default();
        assert_eq!(config.half_life_hours, 24.0);
        assert_eq!(config.min_decay, 0.3);
        assert_eq!(config.recent_boost, 1.5);
    }

    #[test]
    fn test_empty_retrieval() {
        let retriever = SmartMemoryRetriever::new();
        let result = retriever.retrieve(&[], &[], &[], 5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_empty_summary() {
        let retriever = SmartMemoryRetriever::new();
        let summary = retriever.generate_smart_summary(&[], &[], &[], 5);
        assert!(summary.is_empty());
    }

    #[test]
    fn test_empty_stats() {
        let retriever = SmartMemoryRetriever::new();
        let stats = retriever.get_retrieval_stats(&[], &[], &[]);
        assert_eq!(stats.total_entries, 0);
        assert_eq!(stats.keyword_matches, 0);
        assert_eq!(stats.avg_score, 0.0);
    }
}