1use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13pub struct AdaptiveLearner {
15 compression_feedback: Vec<CompressionFeedback>,
17 focus_feedback: Vec<FocusFeedback>,
19 retrieval_feedback: Vec<RetrievalFeedback>,
21 preferences: AdaptivePreferences,
23 stats: FeedbackStats,
25}
26
/// One piece of user feedback on a compression result.
///
/// Created by `AdaptiveLearner::record_compression_feedback`. Field order is
/// also the serialized field order, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionFeedback {
    /// When the feedback was recorded (UTC).
    timestamp: DateTime<Utc>,
    /// Identifier of the session the feedback belongs to.
    session_id: String,
    /// Token count before compression, as supplied by the caller.
    original_tokens: u32,
    /// Token count after compression, as supplied by the caller.
    compressed_tokens: u32,
    /// Name of the compression stage that was applied (e.g. `"RemoveLowPriority"`);
    /// acceptance is aggregated per stage to choose the preferred stage.
    stage: String,
    /// User rating, clamped to `1..=5` when recorded.
    rating: u8,
    /// Optional free-form user comment.
    comments: Option<String>,
    /// Whether the user accepted the compressed result.
    accepted: bool,
}
47
/// One piece of user feedback on a detected focus topic.
///
/// Created by `AdaptiveLearner::record_focus_feedback`. Field order is also
/// the serialized field order, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FocusFeedback {
    /// When the feedback was recorded (UTC).
    timestamp: DateTime<Utc>,
    /// Identifier of the session the feedback belongs to.
    session_id: String,
    /// The focus topic the feedback refers to.
    focus_topic: String,
    /// User rating, clamped to `1..=5` when recorded.
    rating: u8,
    /// Whether the user judged the focus detection accurate; feeds the focus accuracy rate.
    accurate: bool,
    /// Optional correction suggested by the user.
    suggested_correction: Option<String>,
}
64
/// One piece of user feedback on a retrieved memory.
///
/// Created by `AdaptiveLearner::record_retrieval_feedback`. Field order is
/// also the serialized field order, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalFeedback {
    /// When the feedback was recorded (UTC).
    timestamp: DateTime<Utc>,
    /// Identifier of the session the feedback belongs to.
    session_id: String,
    /// Identifier of the retrieved memory.
    memory_id: String,
    /// Memory content, truncated to at most 100 characters when recorded.
    memory_content: String,
    /// User rating, clamped to `1..=5` when recorded.
    rating: u8,
    /// Whether the user judged the memory relevant; feeds the retrieval relevance rate.
    relevant: bool,
    /// Optional context the user suggests would have been better.
    suggested_context: Option<String>,
}
83
/// Preferences adapted from feedback; initial values come from `Default`.
///
/// Field order is also the serialized field order, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptivePreferences {
    /// Compression aggressiveness; starts at 0.5 and adaptation sets it to 0.7 or 0.3.
    compression_aggressiveness: f32,
    /// Focus-detection sensitivity; starts at 0.7 and adaptation sets it to 0.5, 0.7 or 0.9.
    focus_sensitivity: f32,
    /// Compression stage with the best observed acceptance rate.
    preferred_stage: String,
    /// Per-factor retrieval weight overrides (adaptation writes `"focus"` and `"tfidf"`);
    /// factors without an entry are read as 0.25.
    retrieval_weights: HashMap<String, f32>,
    /// NOTE(review): not read or written by the adaptation logic in this file — confirm its use.
    category_preferences: HashMap<String, f32>,
    /// Time of construction or of the most recent adaptation pass (UTC).
    last_updated: DateTime<Utc>,
}
100
101impl Default for AdaptivePreferences {
102 fn default() -> Self {
103 Self {
104 compression_aggressiveness: 0.5,
105 focus_sensitivity: 0.7,
106 preferred_stage: "RemoveLowPriority".to_string(),
107 retrieval_weights: HashMap::new(),
108 category_preferences: HashMap::new(),
109 last_updated: Utc::now(),
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct FeedbackStats {
117 compression_count: usize,
119 focus_count: usize,
121 retrieval_count: usize,
123 avg_compression_rating: f32,
125 avg_focus_rating: f32,
127 avg_retrieval_rating: f32,
129 compression_accept_rate: f32,
131 focus_accuracy_rate: f32,
133 retrieval_relevance_rate: f32,
135}
136
137impl Default for FeedbackStats {
138 fn default() -> Self {
139 Self {
140 compression_count: 0,
141 focus_count: 0,
142 retrieval_count: 0,
143 avg_compression_rating: 0.0,
144 avg_focus_rating: 0.0,
145 avg_retrieval_rating: 0.0,
146 compression_accept_rate: 0.0,
147 focus_accuracy_rate: 0.0,
148 retrieval_relevance_rate: 0.0,
149 }
150 }
151}
152
153impl AdaptiveLearner {
154 pub fn new() -> Self {
156 Self {
157 compression_feedback: Vec::new(),
158 focus_feedback: Vec::new(),
159 retrieval_feedback: Vec::new(),
160 preferences: AdaptivePreferences::default(),
161 stats: FeedbackStats::default(),
162 }
163 }
164
165 pub fn record_compression_feedback(
167 &mut self,
168 session_id: &str,
169 original_tokens: u32,
170 compressed_tokens: u32,
171 stage: &str,
172 rating: u8,
173 accepted: bool,
174 comments: Option<String>,
175 ) {
176 let feedback = CompressionFeedback {
177 timestamp: Utc::now(),
178 session_id: session_id.to_string(),
179 original_tokens,
180 compressed_tokens,
181 stage: stage.to_string(),
182 rating: rating.clamp(1, 5),
183 comments,
184 accepted,
185 };
186
187 self.compression_feedback.push(feedback);
188 self.update_compression_preferences();
189 }
190
191 pub fn record_focus_feedback(
193 &mut self,
194 session_id: &str,
195 focus_topic: &str,
196 rating: u8,
197 accurate: bool,
198 suggested_correction: Option<String>,
199 ) {
200 let feedback = FocusFeedback {
201 timestamp: Utc::now(),
202 session_id: session_id.to_string(),
203 focus_topic: focus_topic.to_string(),
204 rating: rating.clamp(1, 5),
205 accurate,
206 suggested_correction,
207 };
208
209 self.focus_feedback.push(feedback);
210 self.update_focus_preferences();
211 }
212
213 pub fn record_retrieval_feedback(
215 &mut self,
216 session_id: &str,
217 memory_id: &str,
218 memory_content: &str,
219 rating: u8,
220 relevant: bool,
221 suggested_context: Option<String>,
222 ) {
223 let truncated_content = if memory_content.len() > 100 {
225 memory_content.chars().take(100).collect::<String>()
226 } else {
227 memory_content.to_string()
228 };
229
230 let feedback = RetrievalFeedback {
231 timestamp: Utc::now(),
232 session_id: session_id.to_string(),
233 memory_id: memory_id.to_string(),
234 memory_content: truncated_content,
235 rating: rating.clamp(1, 5),
236 relevant,
237 suggested_context,
238 };
239
240 self.retrieval_feedback.push(feedback);
241 self.update_retrieval_preferences();
242 }
243
244 fn update_compression_preferences(&mut self) {
246 if self.compression_feedback.len() < 5 {
247 return; }
249
250 let stage_acceptance: HashMap<String, f32> = self
252 .compression_feedback
253 .iter()
254 .fold(HashMap::new(), |mut acc, f| {
255 let entry = acc.entry(f.stage.clone()).or_insert((0usize, 0usize));
256 if f.accepted {
257 entry.0 += 1;
258 }
259 entry.1 += 1;
260 acc
261 })
262 .iter()
263 .map(|(stage, (accepted, total))| {
264 (stage.clone(), *accepted as f32 / *total as f32)
265 })
266 .collect();
267
268 let best_stage = stage_acceptance
270 .iter()
271 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
272 .map(|(s, _)| s.clone())
273 .unwrap_or_else(|| "RemoveLowPriority".to_string());
274
275 self.preferences.preferred_stage = best_stage;
276
277 let total_rating = self.compression_feedback.iter().map(|f| f.rating as f32).sum::<f32>();
279 let total_accepted = self.compression_feedback.iter().filter(|f| f.accepted).count();
280
281 self.stats.avg_compression_rating = total_rating / self.compression_feedback.len() as f32;
282 self.stats.compression_accept_rate = total_accepted as f32 / self.compression_feedback.len() as f32;
283
284 if self.stats.avg_compression_rating > 4.0 && self.stats.compression_accept_rate > 0.8 {
286 self.preferences.compression_aggressiveness = 0.7;
288 } else if self.stats.avg_compression_rating < 3.0 || self.stats.compression_accept_rate < 0.5 {
289 self.preferences.compression_aggressiveness = 0.3;
291 }
292
293 self.preferences.last_updated = Utc::now();
294 self.stats.compression_count = self.compression_feedback.len();
295 }
296
297 fn update_focus_preferences(&mut self) {
299 if self.focus_feedback.len() < 5 {
300 return;
301 }
302
303 let accurate_count = self.focus_feedback.iter().filter(|f| f.accurate).count();
305 self.stats.focus_accuracy_rate = accurate_count as f32 / self.focus_feedback.len() as f32;
306
307 let total_rating = self.focus_feedback.iter().map(|f| f.rating as f32).sum::<f32>();
309 self.stats.avg_focus_rating = total_rating / self.focus_feedback.len() as f32;
310
311 if self.stats.focus_accuracy_rate > 0.9 {
313 self.preferences.focus_sensitivity = 0.7;
315 } else if self.stats.focus_accuracy_rate < 0.5 {
316 self.preferences.focus_sensitivity = 0.9;
318 } else if self.stats.focus_accuracy_rate > 0.7 && self.stats.avg_focus_rating > 4.0 {
319 self.preferences.focus_sensitivity = 0.5;
321 }
322
323 self.preferences.last_updated = Utc::now();
324 self.stats.focus_count = self.focus_feedback.len();
325 }
326
327 fn update_retrieval_preferences(&mut self) {
329 if self.retrieval_feedback.len() < 5 {
330 return;
331 }
332
333 let relevant_count = self.retrieval_feedback.iter().filter(|f| f.relevant).count();
335 self.stats.retrieval_relevance_rate = relevant_count as f32 / self.retrieval_feedback.len() as f32;
336
337 let total_rating = self.retrieval_feedback.iter().map(|f| f.rating as f32).sum::<f32>();
339 self.stats.avg_retrieval_rating = total_rating / self.retrieval_feedback.len() as f32;
340
341 if self.stats.retrieval_relevance_rate < 0.5 {
343 self.preferences.retrieval_weights.insert("focus".to_string(), 0.35);
345 self.preferences.retrieval_weights.insert("tfidf".to_string(), 0.25);
346 } else if self.stats.retrieval_relevance_rate > 0.8 && self.stats.avg_retrieval_rating > 4.0 {
347 self.preferences.retrieval_weights.insert("focus".to_string(), 0.25);
349 self.preferences.retrieval_weights.insert("tfidf".to_string(), 0.30);
350 }
351
352 self.preferences.last_updated = Utc::now();
353 self.stats.retrieval_count = self.retrieval_feedback.len();
354 }
355
356 pub fn get_preferences(&self) -> &AdaptivePreferences {
358 &self.preferences
359 }
360
361 pub fn get_stats(&self) -> &FeedbackStats {
363 &self.stats
364 }
365
366 pub fn get_compression_aggressiveness(&self) -> f32 {
368 self.preferences.compression_aggressiveness
369 }
370
371 pub fn get_focus_sensitivity(&self) -> f32 {
373 self.preferences.focus_sensitivity
374 }
375
376 pub fn get_preferred_stage(&self) -> &str {
378 &self.preferences.preferred_stage
379 }
380
381 pub fn get_retrieval_weight(&self, factor: &str) -> f32 {
383 self.preferences.retrieval_weights.get(factor).copied().unwrap_or(0.25)
384 }
385
386 pub fn export_feedback(&self) -> FeedbackExport {
388 FeedbackExport {
389 compression_feedback: self.compression_feedback.clone(),
390 focus_feedback: self.focus_feedback.clone(),
391 retrieval_feedback: self.retrieval_feedback.clone(),
392 preferences: self.preferences.clone(),
393 stats: self.stats.clone(),
394 }
395 }
396
397 pub fn import_feedback(&mut self, export: FeedbackExport) {
399 self.compression_feedback = export.compression_feedback;
400 self.focus_feedback = export.focus_feedback;
401 self.retrieval_feedback = export.retrieval_feedback;
402 self.preferences = export.preferences;
403 self.stats = export.stats;
404
405 self.update_compression_preferences();
407 self.update_focus_preferences();
408 self.update_retrieval_preferences();
409 }
410
411 pub fn prune_old_feedback(&mut self) {
413 if self.compression_feedback.len() > 100 {
414 self.compression_feedback = self.compression_feedback.iter().rev().take(100).rev().cloned().collect();
415 }
416 if self.focus_feedback.len() > 100 {
417 self.focus_feedback = self.focus_feedback.iter().rev().take(100).rev().cloned().collect();
418 }
419 if self.retrieval_feedback.len() > 100 {
420 self.retrieval_feedback = self.retrieval_feedback.iter().rev().take(100).rev().cloned().collect();
421 }
422 }
423
424 pub fn generate_report(&self) -> String {
426 let mut report = String::from("【自适应学习报告】\n\n");
427
428 report.push_str(&format!(
429 "压缩偏好:\n 激进程度: {:.0}%\n 首选阶段: {}\n 平均评分: {:.1}\n 接受率: {:.0}%\n\n",
430 self.preferences.compression_aggressiveness * 100.0,
431 self.preferences.preferred_stage,
432 self.stats.avg_compression_rating,
433 self.stats.compression_accept_rate * 100.0
434 ));
435
436 report.push_str(&format!(
437 "聚焦检测:\n 灵敏度: {:.0}%\n 准确率: {:.0}%\n 平均评分: {:.1}\n\n",
438 self.preferences.focus_sensitivity * 100.0,
439 self.stats.focus_accuracy_rate * 100.0,
440 self.stats.avg_focus_rating
441 ));
442
443 report.push_str(&format!(
444 "记忆检索:\n 相关率: {:.0}%\n 平均评分: {:.1}\n\n",
445 self.stats.retrieval_relevance_rate * 100.0,
446 self.stats.avg_retrieval_rating
447 ));
448
449 report.push_str(&format!(
450 "反馈统计:\n 压缩反馈: {} 条\n 聚焦反馈: {} 条\n 检索反馈: {} 条\n",
451 self.stats.compression_count,
452 self.stats.focus_count,
453 self.stats.retrieval_count
454 ));
455
456 report
457 }
458}
459
460impl Default for AdaptiveLearner {
461 fn default() -> Self {
462 Self::new()
463 }
464}
465
/// Serializable snapshot of a learner's full state.
///
/// Produced by `AdaptiveLearner::export_feedback` and consumed by
/// `AdaptiveLearner::import_feedback`. Field order is also the serialized
/// field order, so keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedbackExport {
    /// All stored compression feedback, oldest first.
    compression_feedback: Vec<CompressionFeedback>,
    /// All stored focus feedback, oldest first.
    focus_feedback: Vec<FocusFeedback>,
    /// All stored retrieval feedback, oldest first.
    retrieval_feedback: Vec<RetrievalFeedback>,
    /// Preferences at export time.
    preferences: AdaptivePreferences,
    /// Statistics at export time.
    stats: FeedbackStats,
}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Records one compression feedback entry with fixed token counts and no comment.
    fn record_compression(
        learner: &mut AdaptiveLearner,
        session: &str,
        stage: &str,
        rating: u8,
        accepted: bool,
    ) {
        learner.record_compression_feedback(session, 10000, 8000, stage, rating, accepted, None);
    }

    #[test]
    fn test_adaptive_learner_creation() {
        let learner = AdaptiveLearner::new();
        assert_eq!(learner.preferences.compression_aggressiveness, 0.5);
    }

    #[test]
    fn test_compression_feedback_recording() {
        let mut learner = AdaptiveLearner::new();
        record_compression(&mut learner, "test-session", "RemoveLowPriority", 4, true);
        assert_eq!(learner.compression_feedback.len(), 1);
    }

    #[test]
    fn test_focus_feedback_recording() {
        let mut learner = AdaptiveLearner::new();
        learner.record_focus_feedback("test-session", "database optimization", 5, true, None);
        assert_eq!(learner.focus_feedback.len(), 1);
    }

    #[test]
    fn test_retrieval_feedback_recording() {
        let mut learner = AdaptiveLearner::new();
        learner.record_retrieval_feedback(
            "test-session",
            "memory-123",
            "Test memory content",
            4,
            true,
            None,
        );
        assert_eq!(learner.retrieval_feedback.len(), 1);
    }

    #[test]
    fn test_retrieval_content_truncated_by_chars() {
        // Multi-byte text: truncation must count characters, not bytes.
        let mut learner = AdaptiveLearner::new();
        let content = "数".repeat(150);
        learner.record_retrieval_feedback("s", "m", &content, 3, true, None);
        assert_eq!(learner.retrieval_feedback[0].memory_content.chars().count(), 100);
    }

    #[test]
    fn test_rating_is_clamped() {
        let mut learner = AdaptiveLearner::new();
        record_compression(&mut learner, "low", "RemoveLowPriority", 0, true);
        record_compression(&mut learner, "high", "RemoveLowPriority", 9, true);
        assert_eq!(learner.compression_feedback[0].rating, 1);
        assert_eq!(learner.compression_feedback[1].rating, 5);
    }

    #[test]
    fn test_preferences_default() {
        let prefs = AdaptivePreferences::default();
        assert_eq!(prefs.compression_aggressiveness, 0.5);
        assert_eq!(prefs.focus_sensitivity, 0.7);
    }

    #[test]
    fn test_retrieval_weight_default() {
        let learner = AdaptiveLearner::new();
        assert_eq!(learner.get_retrieval_weight("unknown-factor"), 0.25);
    }

    #[test]
    fn test_compression_adaptation_after_threshold() {
        let mut learner = AdaptiveLearner::new();
        for i in 0..5 {
            record_compression(&mut learner, &format!("session-{}", i), "Summarize", 5, true);
        }
        assert_eq!(learner.get_preferred_stage(), "Summarize");
        assert_eq!(learner.get_compression_aggressiveness(), 0.7);
    }

    #[test]
    fn test_feedback_pruning() {
        let mut learner = AdaptiveLearner::new();

        for i in 0..150 {
            record_compression(
                &mut learner,
                &format!("session-{}", i),
                "RemoveLowPriority",
                4,
                true,
            );
        }

        learner.prune_old_feedback();
        assert_eq!(learner.compression_feedback.len(), 100);
        // The newest entries are the ones retained, in their original order.
        assert_eq!(learner.compression_feedback[0].session_id, "session-50");
        assert_eq!(learner.compression_feedback[99].session_id, "session-149");
    }

    #[test]
    fn test_report_generation() {
        let learner = AdaptiveLearner::new();
        let report = learner.generate_report();
        assert!(report.contains("自适应学习报告"));
    }
}