matrixcode_core/compress/
focus.rs1use crate::memory::ExtractedKeywords;
7use crate::providers::{ContentBlock, Message, MessageContent, Role};
8use super::focus_config::FocusTrackerConfig;
9
10#[derive(Debug, Clone)]
12pub struct ConversationFocus {
13 pub current_topic: Option<String>,
15 pub current_question: Option<String>,
17 pub recent_context: Vec<String>,
19 pub topic_transitions: Vec<TopicTransition>,
21 pub detected_at: usize, }
24
/// A change of topic detected at a user message that contains a transition
/// keyword.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TopicTransition {
    /// Topic in effect before the transition.
    pub from_topic: String,
    /// Topic introduced by the transition message.
    pub to_topic: String,
    /// Index of the transition message in the full message list.
    pub message_index: usize,
    /// The configured transition keyword that triggered the detection.
    pub transition_keyword: String,
}
33
34pub struct FocusTracker {
36 config: FocusTrackerConfig,
38}
39
40impl Default for FocusTracker {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl FocusTracker {
47 pub fn new() -> Self {
49 Self {
50 config: FocusTrackerConfig::default(),
51 }
52 }
53
54 pub fn with_config(config: FocusTrackerConfig) -> Self {
56 Self { config }
57 }
58
59 pub fn config(&self) -> &FocusTrackerConfig {
61 &self.config
62 }
63
64 pub fn config_mut(&mut self) -> &mut FocusTrackerConfig {
66 &mut self.config
67 }
68
69 pub fn set_current_keywords(&mut self, keywords: &ExtractedKeywords) {
73 self.config.set_keywords(keywords);
74 }
75
76 pub fn merge_keywords(&mut self, additional: &ExtractedKeywords) {
78 self.config.merge_keywords(additional);
79 }
80
81 pub fn clear_keywords(&mut self) {
83 self.config.clear_keywords();
84 }
85
86 pub fn detect_focus(&self, messages: &[Message]) -> ConversationFocus {
88 self.detect_focus_with_window(messages, self.config.focus_window_size)
89 }
90
91 pub fn detect_focus_with_window(&self, messages: &[Message], window_size: usize) -> ConversationFocus {
93 let recent_start = messages.len().saturating_sub(window_size);
94 let recent_messages = &messages[recent_start..];
95
96 let mut focus = ConversationFocus {
97 current_topic: None,
98 current_question: None,
99 recent_context: Vec::new(),
100 topic_transitions: Vec::new(),
101 detected_at: messages.len().saturating_sub(1),
102 };
103
104 for (_idx, msg) in recent_messages.iter().enumerate().rev() {
106 if let Some(key_point) = self.extract_key_point(msg) {
107 focus.recent_context.push(key_point);
108 if focus.recent_context.len() >= self.config.max_recent_context_count {
109 break;
110 }
111 }
112 }
113
114 for msg in recent_messages.iter().rev() {
116 if matches!(msg.role, Role::User) {
117 if let Some(question) = self.extract_current_question(msg) {
118 focus.current_question = Some(question);
119 break;
120 }
121 }
122 }
123
124 focus.topic_transitions = self.detect_topic_transitions(messages);
126
127 if let Some(last_transition) = focus.topic_transitions.last() {
129 focus.current_topic = Some(last_transition.to_topic.clone());
130 } else {
131 focus.current_topic = self.extract_initial_topic(messages);
133 }
134
135 focus
136 }
137
138 fn extract_key_point(&self, message: &Message) -> Option<String> {
140 match &message.content {
141 MessageContent::Text(text) => {
142 let sentences: Vec<&str> = text.split(|c| c == '.' || c == '。' || c == '\n')
144 .filter(|s| s.trim().len() > self.config.min_substantial_text_length)
145 .collect();
146
147 sentences.first().map(|s| s.trim().to_string())
148 }
149 MessageContent::Blocks(blocks) => {
150 for block in blocks {
151 if let ContentBlock::Text { text } = block {
152 if text.len() > self.config.min_substantial_text_length {
153 return Some(text.split('\n').next()?.trim().to_string());
154 }
155 }
156 }
157 None
158 }
159 }
160 }
161
162 fn extract_current_question(&self, message: &Message) -> Option<String> {
164 match &message.content {
165 MessageContent::Text(text) => {
166 if self.config.matches_question(text) {
168 let question = text.chars()
170 .take(self.config.max_question_extract_length)
171 .collect::<String>();
172 return Some(question.trim().to_string());
173 }
174
175 if self.config.matches_task(text) {
177 let task = text.chars()
178 .take(self.config.max_question_extract_length)
179 .collect::<String>();
180 return Some(task.trim().to_string());
181 }
182
183 if text.len() > self.config.min_substantial_text_length * 2 {
185 Some(text.chars()
186 .take(self.config.max_question_extract_length)
187 .collect::<String>()
188 .trim()
189 .to_string())
190 } else {
191 None
192 }
193 }
194 MessageContent::Blocks(blocks) => {
195 for block in blocks {
196 if let ContentBlock::Text { text } = block {
197 if text.len() > self.config.min_substantial_text_length {
198 return Some(text.chars()
199 .take(self.config.max_question_extract_length)
200 .collect::<String>());
201 }
202 }
203 }
204 None
205 }
206 }
207 }
208
209 fn detect_topic_transitions(&self, messages: &[Message]) -> Vec<TopicTransition> {
211 let mut transitions = Vec::new();
212 let mut prev_topic = String::new();
213
214 let transition_keywords = self.config.transition_keywords();
216
217 for (idx, msg) in messages.iter().enumerate() {
218 if matches!(msg.role, Role::User) {
219 let text = match &msg.content {
220 MessageContent::Text(t) => t.clone(),
221 MessageContent::Blocks(blocks) => {
222 blocks.iter()
223 .filter_map(|b| {
224 if let ContentBlock::Text { text } = b {
225 Some(text.clone())
226 } else {
227 None
228 }
229 })
230 .collect::<Vec<_>>()
231 .join(" ")
232 }
233 };
234
235 let lower = text.to_lowercase();
236
237 for keyword in &transition_keywords {
239 if lower.contains(&keyword.to_lowercase()) {
240 let new_topic = self.extract_topic_from_message(&text);
242
243 if !prev_topic.is_empty() && new_topic != prev_topic {
244 transitions.push(TopicTransition {
245 from_topic: prev_topic.clone(),
246 to_topic: new_topic.clone(),
247 message_index: idx,
248 transition_keyword: keyword.clone(),
249 });
250 }
251
252 prev_topic = new_topic;
253 break;
254 }
255 }
256
257 if prev_topic.is_empty() {
259 prev_topic = self.extract_topic_from_message(&text);
260 }
261 }
262 }
263
264 transitions
265 }
266
267 fn extract_topic_from_message(&self, text: &str) -> String {
269 let found = self.config.find_tech_keywords(text);
271
272 if found.is_empty() {
273 text.split_whitespace()
275 .take(self.config.fallback_topic_word_count)
276 .collect::<Vec<_>>()
277 .join(" ")
278 } else {
279 found.join(", ")
280 }
281 }
282
283 fn extract_initial_topic(&self, messages: &[Message]) -> Option<String> {
285 for msg in messages {
286 if matches!(msg.role, Role::User) {
287 let text = match &msg.content {
288 MessageContent::Text(t) => t.clone(),
289 MessageContent::Blocks(blocks) => {
290 blocks.iter()
291 .filter_map(|b| {
292 if let ContentBlock::Text { text } = b {
293 Some(text.clone())
294 } else {
295 None
296 }
297 })
298 .collect::<Vec<_>>()
299 .join(" ")
300 }
301 };
302
303 if text.len() > self.config.min_substantial_text_length {
304 return Some(self.extract_topic_from_message(&text));
305 }
306 }
307 }
308 None
309 }
310
311 pub fn focus_score(&self, message: &Message, focus: &ConversationFocus) -> f32 {
313 let mut score = 0.0;
314
315 let text = match &message.content {
317 MessageContent::Text(t) => t.clone(),
318 MessageContent::Blocks(blocks) => {
319 blocks.iter()
320 .filter_map(|b| {
321 if let ContentBlock::Text { text } = b {
322 Some(text.clone())
323 } else {
324 None
325 }
326 })
327 .collect::<Vec<_>>()
328 .join(" ")
329 }
330 };
331
332 let lower = text.to_lowercase();
333
334 if let Some(topic) = &focus.current_topic {
336 let topic_keywords: Vec<&str> = topic.split(", ").collect();
337 for kw in topic_keywords {
338 if lower.contains(kw) {
339 score += 0.3;
340 }
341 }
342 }
343
344 if let Some(question) = &focus.current_question {
346 let question_lower = question.to_lowercase();
347 let words: Vec<&str> = question_lower.split_whitespace().collect();
348 for word in words {
349 if word.len() > 3 && lower.contains(word) {
350 score += 0.1;
351 }
352 }
353 }
354
355 if let Some(key_point) = self.extract_key_point(message) {
357 if focus.recent_context.contains(&key_point) {
358 score += 0.5;
359 }
360 }
361
362 score = (score * self.config.focus_score_boost).min(self.config.max_focus_score);
364
365 score
366 }
367
368 pub fn focus_score_with_keywords(&self, message: &Message, focus: &ConversationFocus) -> f32 {
380 let keywords = self.config.get_keywords();
381
382 let text = match &message.content {
384 MessageContent::Text(t) => t.clone(),
385 MessageContent::Blocks(blocks) => {
386 blocks.iter()
387 .filter_map(|b| {
388 if let ContentBlock::Text { text } = b {
389 Some(text.clone())
390 } else {
391 None
392 }
393 })
394 .collect::<Vec<_>>()
395 .join(" ")
396 }
397 };
398
399 let lower = text.to_lowercase();
400 let mut score = 0.0;
401
402 if let Some(kw) = keywords {
404 for keyword in &kw.transition {
406 if lower.contains(&keyword.to_lowercase()) {
407 score += 0.2; }
409 }
410
411 for keyword in &kw.question {
413 if lower.contains(&keyword.to_lowercase()) {
414 score += 0.3; }
416 }
417
418 for keyword in &kw.task {
420 if lower.contains(&keyword.to_lowercase()) {
421 score += 0.25; }
423 }
424
425 for keyword in &kw.tech {
427 if lower.contains(&keyword.to_lowercase()) {
428 score += 0.15; }
430 }
431 } else {
432 return self.focus_score(message, focus);
434 }
435
436 if let Some(topic) = &focus.current_topic {
438 let topic_keywords: Vec<&str> = topic.split(", ").collect();
439 for kw in topic_keywords {
440 if lower.contains(&kw.to_lowercase()) {
441 score += 0.1;
442 }
443 }
444 }
445
446 if let Some(question) = &focus.current_question {
447 let question_lower = question.to_lowercase();
448 for word in question_lower.split_whitespace() {
449 if word.len() > 3 && lower.contains(word) {
450 score += 0.05;
451 }
452 }
453 }
454
455 score = (score * self.config.focus_score_boost).min(self.config.max_focus_score);
457
458 score
459 }
460
461 pub fn create_focus_message(&self, focus: &ConversationFocus) -> Message {
463 let mut content_parts = Vec::new();
464
465 if let Some(topic) = &focus.current_topic {
467 content_parts.push(format!("当前话题: {}", topic));
468 }
469
470 if let Some(question) = &focus.current_question {
472 content_parts.push(format!("当前问题/任务: {}", question));
473 }
474
475 if !focus.recent_context.is_empty() {
477 content_parts.push(format!("最近上下文摘要: {}", focus.recent_context.join(" | ")));
478 }
479
480 if !focus.topic_transitions.is_empty() {
482 let transitions: Vec<String> = focus.topic_transitions.iter()
483 .map(|t| format!("{} -> {}", t.from_topic, t.to_topic))
484 .collect();
485 content_parts.push(format!("话题转换历史: {}", transitions.join(", ")));
486 }
487
488 let content = if content_parts.is_empty() {
489 "[焦点追踪系统初始化]".to_string()
490 } else {
491 format!("【焦点上下文】\n{}\n请基于上述焦点继续对话。", content_parts.join("\n"))
492 };
493
494 Message {
495 role: Role::System,
496 content: MessageContent::Text(content),
497 }
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain-text message with the given role.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Converts string literals into the owned lists `ExtractedKeywords` stores.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|word| word.to_string()).collect()
    }

    #[test]
    fn test_focus_tracker_creation() {
        assert!(FocusTracker::new().config().validate());
    }

    #[test]
    fn test_focus_tracker_with_custom_config() {
        let tracker = FocusTracker::with_config(FocusTrackerConfig::simple_conversation());
        assert_eq!(tracker.config().focus_window_size, 5);
    }

    #[test]
    fn test_detect_focus() {
        let tracker = FocusTracker::new();
        let conversation = [
            text_message(Role::User, "如何优化 Rust 性能?"),
            text_message(Role::Assistant, "我来帮你优化 Rust 代码性能。"),
            text_message(Role::User, "帮我实现一个压缩算法"),
        ];

        let focus = tracker.detect_focus(&conversation);

        assert!(focus.current_question.is_some());
        assert!(focus.topic_transitions.is_empty());
    }

    #[test]
    fn test_focus_score() {
        let tracker = FocusTracker::new();
        let history = [text_message(Role::User, "如何优化 Rust 性能?")];
        let focus = tracker.detect_focus(&history);

        let relevant = text_message(Role::Assistant, "Rust 性能优化的关键是...");

        assert!(tracker.focus_score(&relevant, &focus) > 0.0);
    }

    #[test]
    fn test_keywords_integration() {
        let mut tracker = FocusTracker::new();
        tracker.set_current_keywords(&ExtractedKeywords {
            transition: owned(&["custom_transition"]),
            question: owned(&["custom_question"]),
            task: owned(&["custom_task"]),
            tech: owned(&["customtech"]),
        });

        let tech = tracker.config().tech_keywords();
        assert!(tech.contains(&"customtech".to_string()));
    }

    #[test]
    fn test_fallback_keywords() {
        let tracker = FocusTracker::new();

        let transition = tracker.config().transition_keywords();
        assert!(!transition.is_empty());
        assert!(transition.contains(&"however".to_string()));
    }

    #[test]
    fn test_matches_keywords() {
        let tracker = FocusTracker::new();
        let config = tracker.config();

        assert!(config.matches_question("How do I do this?"));
        assert!(config.matches_task("Please implement this"));
        assert!(config.matches_transition("However, let's move on"));
    }

    #[test]
    fn test_topic_extraction() {
        let tracker = FocusTracker::new();

        let tech_topic = tracker.extract_topic_from_message("使用 Rust 和 Python 开发项目");
        assert!(tech_topic.contains("rust"));
        assert!(tech_topic.contains("python"));

        let chat_topic = tracker.extract_topic_from_message("随便聊聊天气");
        assert!(!chat_topic.is_empty());
    }

    #[test]
    fn test_clear_keywords() {
        let mut tracker = FocusTracker::new();
        tracker.set_current_keywords(&ExtractedKeywords {
            transition: owned(&["test"]),
            question: Vec::new(),
            task: Vec::new(),
            tech: Vec::new(),
        });
        assert!(tracker.config().get_keywords().is_some());

        tracker.clear_keywords();
        assert!(tracker.config().get_keywords().is_none());

        // After clearing, the built-in transition keywords are used again.
        assert!(tracker
            .config()
            .transition_keywords()
            .contains(&"however".to_string()));
    }
}