1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8
9use crate::providers::{ChatRequest, ContentBlock, Message, MessageContent, Provider, Role};
10use super::focus::{ConversationFocus, TopicTransition};
11use super::focus_config::FocusTrackerConfig;
12
/// System prompt for the focus-analysis model call (sent as `ChatRequest::system`
/// by `AiFocusTracker::call_ai`).
///
/// The prompt is written in Chinese. It asks the model to judge how a new message
/// relates to the current conversation focus. The model must reply with a single
/// JSON object whose keys (`relevance`, `is_focus_update`, `new_topic`,
/// `new_question`, `context_to_add`, `reason`) match the fields of
/// `FocusAnalysisResult`.
///
/// Summary of the instructions:
/// - `relevance` is a 0.0-1.0 semantic score (1.0 = directly continues the current
///   question/task, 0.0 = unrelated or the topic has switched);
/// - `is_focus_update` is true when the topic, question or task clearly changes;
/// - similar-but-different concepts (e.g. "压缩" vs "解压缩") count as different focuses;
/// - reply with JSON only; `new_topic`/`new_question` may be omitted when the focus
///   does not change, and `context_to_add` is filled only for important context.
const FOCUS_ANALYSIS_PROMPT: &str = r#"你是焦点分析助手。分析新消息与当前会话焦点的关系。

## 分析维度

1. **relevance** (0.0-1.0): 与当前焦点的相关性
 - 1.0: 直接回答当前问题或继续当前任务
 - 0.7-0.9: 高度相关,提供重要上下文
 - 0.4-0.6: 中等相关,有联系但不直接
 - 0.1-0.3: 低相关,可能偏离话题
 - 0.0: 完全不相关或话题已切换

2. **is_focus_update** (true/false): 是否需要更新焦点
 - true: 当话题明显转换、新问题提出、任务切换时
 - false: 继续当前话题时

3. **语义差异检测**: 注意区分相似但不同的概念
 - 例如: "压缩" vs "解压缩" 是不同任务
 - 例如: "优化性能" vs "优化内存" 是不同焦点

## 输出格式(严格 JSON)

```json
{
 "relevance": 0.8,
 "is_focus_update": false,
 "new_topic": "新话题名称(如果需要更新)",
 "new_question": "新问题(如果需要更新)",
 "context_to_add": "需要添加到上下文的关键信息",
 "reason": "判断理由简述"
}
```

## 规则

1. 只返回 JSON,不要其他解释
2. 如果不需要更新焦点,`new_topic` 和 `new_question` 可以省略
3. `context_to_add` 只在有重要上下文信息时填写
4. relevance 应基于语义理解,不是简单的关键词匹配"#;
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct FocusAnalysisResult {
56 pub relevance: f32,
58 pub is_focus_update: bool,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub new_topic: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub new_question: Option<String>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub context_to_add: Option<String>,
69 pub reason: String,
71}
72
73impl Default for FocusAnalysisResult {
74 fn default() -> Self {
75 Self {
76 relevance: 0.5,
77 is_focus_update: false,
78 new_topic: None,
79 new_question: None,
80 context_to_add: None,
81 reason: "Default result (AI analysis not performed)".to_string(),
82 }
83 }
84}
85
86pub struct AiFocusTracker {
88 provider: Box<dyn Provider>,
90 model: String,
92 current_focus: Option<ConversationFocus>,
94 config: FocusTrackerConfig,
96 analysis_cache: Vec<(String, FocusAnalysisResult)>,
98 max_cache_size: usize,
100}
101
102impl AiFocusTracker {
103 pub fn new(provider: Box<dyn Provider>, model: String) -> Self {
109 Self {
110 provider,
111 model,
112 current_focus: None,
113 config: FocusTrackerConfig::default(),
114 analysis_cache: Vec::new(),
115 max_cache_size: 50,
116 }
117 }
118
119 pub fn with_config(provider: Box<dyn Provider>, model: String, config: FocusTrackerConfig) -> Self {
121 Self {
122 provider,
123 model,
124 current_focus: None,
125 config,
126 analysis_cache: Vec::new(),
127 max_cache_size: 50,
128 }
129 }
130
131 pub fn new_minimal(model: String) -> Self {
133 Self {
134 provider: crate::create_minimal_provider(&model),
135 model,
136 current_focus: None,
137 config: FocusTrackerConfig::default(),
138 analysis_cache: Vec::new(),
139 max_cache_size: 50,
140 }
141 }
142
143 pub fn current_focus(&self) -> Option<&ConversationFocus> {
145 self.current_focus.as_ref()
146 }
147
148 pub fn set_focus(&mut self, focus: ConversationFocus) {
150 self.current_focus = Some(focus);
151 }
152
153 pub fn clear_focus(&mut self) {
155 self.current_focus = None;
156 self.analysis_cache.clear();
157 }
158
159 pub fn config(&self) -> &FocusTrackerConfig {
161 &self.config
162 }
163
164 pub fn config_mut(&mut self) -> &mut FocusTrackerConfig {
166 &mut self.config
167 }
168
169 pub async fn analyze_message(&mut self, message: &Message) -> Result<FocusAnalysisResult> {
180 let message_key = self.message_cache_key(message);
182 if let Some((_, cached)) = self.analysis_cache.iter().find(|(k, _)| k == &message_key) {
183 log::debug!("Using cached focus analysis result");
184 return Ok(cached.clone());
185 }
186
187 let prompt = self.build_focus_analysis_prompt(message);
189
190 let response = self.call_ai(&prompt).await?;
192
193 let result = self.parse_analysis_result(&response)?;
195
196 if result.is_focus_update {
198 self.update_focus_from_result(&result, message);
199 }
200
201 self.cache_result(message_key, result.clone());
203
204 Ok(result)
205 }
206
207 pub async fn analyze_key_messages(&mut self, messages: &[Message]) -> Result<Vec<(usize, FocusAnalysisResult)>> {
218 let mut results = Vec::new();
219
220 for (idx, msg) in messages.iter().enumerate() {
222 let is_key = matches!(msg.role, Role::User)
223 || idx == 0
224 || idx == messages.len() - 1;
225
226 if is_key {
227 let result = self.analyze_message(msg).await?;
228 results.push((idx, result));
229 }
230 }
231
232 Ok(results)
233 }
234
235 fn build_focus_analysis_prompt(&self, message: &Message) -> String {
237 let current_focus_text = self.format_current_focus();
238 let message_text = self.format_message(message);
239
240 format!(
241 "分析新消息与当前会话焦点的关系:\n\n{}\n\n新消息:\n{}\n\n请返回 JSON 格式分析结果。",
242 current_focus_text,
243 message_text
244 )
245 }
246
247 fn format_current_focus(&self) -> String {
249 match &self.current_focus {
250 Some(focus) => {
251 let mut parts = Vec::new();
252
253 if let Some(topic) = &focus.current_topic {
254 parts.push(format!("当前话题: {}", topic));
255 }
256
257 if let Some(question) = &focus.current_question {
258 parts.push(format!("当前问题/任务: {}", question));
259 }
260
261 if !focus.recent_context.is_empty() {
262 parts.push(format!("最近上下文: {}", focus.recent_context.join(" | ")));
263 }
264
265 if !focus.topic_transitions.is_empty() {
266 let transitions: Vec<String> = focus.topic_transitions.iter()
267 .map(|t| format!("{} -> {}", t.from_topic, t.to_topic))
268 .collect();
269 parts.push(format!("话题转换历史: {}", transitions.join(", ")));
270 }
271
272 if parts.is_empty() {
273 "当前焦点: (尚未建立明确焦点)".to_string()
274 } else {
275 format!("当前焦点:\n{}", parts.join("\n"))
276 }
277 }
278 None => "当前焦点: (尚未建立明确焦点,这是对话开始)".to_string(),
279 }
280 }
281
282 fn format_message(&self, message: &Message) -> String {
284 let role = match message.role {
285 Role::User => "用户",
286 Role::Assistant => "助手",
287 Role::System => "系统",
288 Role::Tool => "工具",
289 };
290
291 let content = match &message.content {
292 MessageContent::Text(text) => text.clone(),
293 MessageContent::Blocks(blocks) => {
294 blocks.iter()
295 .filter_map(|b| {
296 if let ContentBlock::Text { text } = b {
297 Some(text.clone())
298 } else {
299 None
300 }
301 })
302 .collect::<Vec<_>>()
303 .join("\n")
304 }
305 };
306
307 let truncated = if content.len() > 500 {
309 format!("{}... (已截断)", &content[..500])
310 } else {
311 content
312 };
313
314 format!("角色: {}\n内容: {}", role, truncated)
315 }
316
317 async fn call_ai(&self, prompt: &str) -> Result<String> {
319 let request = ChatRequest {
320 messages: vec![Message {
321 role: Role::User,
322 content: MessageContent::Text(prompt.to_string()),
323 }],
324 tools: vec![],
325 system: Some(FOCUS_ANALYSIS_PROMPT.to_string()),
326 think: false,
327 max_tokens: 256, server_tools: vec![],
329 enable_caching: false,
330 };
331
332 let response = self.provider.chat(request).await?;
333
334 let text = response.content.iter()
336 .filter_map(|b| {
337 if let ContentBlock::Text { text } = b {
338 Some(text.clone())
339 } else {
340 None
341 }
342 })
343 .collect::<Vec<_>>()
344 .join("");
345
346 Ok(text)
347 }
348
349 fn parse_analysis_result(&self, response: &str) -> Result<FocusAnalysisResult> {
351 let cleaned = response
353 .trim()
354 .trim_start_matches("```json")
355 .trim_start_matches("```")
356 .trim_end_matches("```")
357 .trim();
358
359 let result: FocusAnalysisResult = serde_json::from_str(cleaned)?;
361
362 let validated = FocusAnalysisResult {
364 relevance: result.relevance.clamp(0.0, 1.0),
365 is_focus_update: result.is_focus_update,
366 new_topic: result.new_topic,
367 new_question: result.new_question,
368 context_to_add: result.context_to_add,
369 reason: result.reason,
370 };
371
372 Ok(validated)
373 }
374
375 fn update_focus_from_result(&mut self, result: &FocusAnalysisResult, message: &Message) {
377 let message_idx = self.current_focus.as_ref()
378 .map(|f| f.detected_at + 1)
379 .unwrap_or(0);
380
381 let message_context = self.extract_message_context(message);
383
384 let new_focus = match &self.current_focus {
385 Some(existing) => {
386 let mut new_focus = ConversationFocus {
388 current_topic: result.new_topic.clone().or(existing.current_topic.clone()),
389 current_question: result.new_question.clone().or(existing.current_question.clone()),
390 recent_context: existing.recent_context.clone(),
391 topic_transitions: existing.topic_transitions.clone(),
392 detected_at: message_idx,
393 };
394
395 if let Some(ctx) = &result.context_to_add {
397 new_focus.recent_context.push(ctx.clone());
398 if new_focus.recent_context.len() > self.config.max_recent_context_count {
399 new_focus.recent_context.remove(0);
400 }
401 }
402
403 if let (Some(new_topic), Some(old_topic)) = (&result.new_topic, &existing.current_topic) {
405 if new_topic != old_topic {
406 new_focus.topic_transitions.push(TopicTransition {
407 from_topic: old_topic.clone(),
408 to_topic: new_topic.clone(),
409 message_index: message_idx,
410 transition_keyword: "AI detected".to_string(),
411 });
412 }
413 }
414
415 new_focus
416 }
417 None => {
418 ConversationFocus {
420 current_topic: result.new_topic.clone().or(message_context.topic),
421 current_question: result.new_question.clone().or(message_context.question),
422 recent_context: result.context_to_add.clone().map(|ctx| vec![ctx]).unwrap_or_default(),
423 topic_transitions: Vec::new(),
424 detected_at: message_idx,
425 }
426 }
427 };
428
429 self.current_focus = Some(new_focus);
430 log::debug!("Focus updated: topic={}, question={}",
431 self.current_focus.as_ref().and_then(|f| f.current_topic.as_ref()).unwrap_or(&"none".to_string()),
432 self.current_focus.as_ref().and_then(|f| f.current_question.as_ref()).unwrap_or(&"none".to_string())
433 );
434 }
435
436 fn extract_message_context(&self, message: &Message) -> MessageContext {
438 let text = match &message.content {
439 MessageContent::Text(t) => t.clone(),
440 MessageContent::Blocks(blocks) => {
441 blocks.iter()
442 .filter_map(|b| {
443 if let ContentBlock::Text { text } = b {
444 Some(text.clone())
445 } else {
446 None
447 }
448 })
449 .collect::<Vec<_>>()
450 .join("\n")
451 }
452 };
453
454 let topic = self.config.find_tech_keywords(&text)
456 .first()
457 .cloned();
458
459 let question = if self.config.matches_question(&text) {
460 Some(text.chars().take(100).collect::<String>())
461 } else {
462 None
463 };
464
465 MessageContext { topic, question }
466 }
467
468 fn message_cache_key(&self, message: &Message) -> String {
470 let content = match &message.content {
471 MessageContent::Text(t) => t.clone(),
472 MessageContent::Blocks(blocks) => {
473 blocks.iter()
474 .filter_map(|b| {
475 if let ContentBlock::Text { text } = b {
476 Some(text.clone())
477 } else {
478 None
479 }
480 })
481 .collect::<Vec<_>>()
482 .join("|")
483 }
484 };
485
486 let key = content.chars().take(100).collect::<String>();
488 format!("{:?}:{}", message.role, key)
489 }
490
491 fn cache_result(&mut self, key: String, result: FocusAnalysisResult) {
493 self.analysis_cache.retain(|(k, _)| k != &key);
495
496 self.analysis_cache.push((key, result));
498
499 if self.analysis_cache.len() > self.max_cache_size {
501 self.analysis_cache.remove(0);
502 }
503 }
504
505 pub fn detect_focus_fallback(&self, messages: &[Message]) -> ConversationFocus {
509 let tracker = super::focus::FocusTracker::with_config(self.config.clone());
511 tracker.detect_focus(messages)
512 }
513
514 pub fn focus_score(&self, message: &Message) -> f32 {
519 let key = self.message_cache_key(message);
521 if let Some((_, result)) = self.analysis_cache.iter().find(|(k, _)| k == &key) {
522 return result.relevance;
523 }
524
525 if let Some(focus) = &self.current_focus {
527 let tracker = super::focus::FocusTracker::with_config(self.config.clone());
528 tracker.focus_score(message, focus)
529 } else {
530 0.5 }
532 }
533
534 pub fn create_focus_message(&self) -> Message {
536 match &self.current_focus {
537 Some(focus) => {
538 let tracker = super::focus::FocusTracker::with_config(self.config.clone());
539 tracker.create_focus_message(focus)
540 }
541 None => {
542 Message {
543 role: Role::System,
544 content: MessageContent::Text("[焦点追踪系统初始化]".to_string()),
545 }
546 }
547 }
548 }
549}
550
/// Topic and question guessed from a message's text without calling the model.
struct MessageContext {
    /// First keyword reported by `FocusTrackerConfig::find_tech_keywords`, if any.
    topic: Option<String>,
    /// Leading excerpt of the text when `FocusTrackerConfig::matches_question` accepts it.
    question: Option<String>,
}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Tracker backed by the minimal provider. No test here calls the async
    /// analysis methods, so the provider is never invoked.
    fn tracker() -> AiFocusTracker {
        AiFocusTracker::new_minimal("test-model".to_string())
    }

    /// Builds a plain-text user message.
    fn user_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Concatenates the text of a message, skipping non-text blocks.
    fn message_text(message: &Message) -> String {
        match &message.content {
            MessageContent::Text(t) => t.clone(),
            MessageContent::Blocks(blocks) => blocks
                .iter()
                .filter_map(|b| {
                    if let ContentBlock::Text { text } = b {
                        Some(text.clone())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>()
                .join(""),
        }
    }

    #[test]
    fn test_focus_analysis_result_default() {
        let result = FocusAnalysisResult::default();
        assert_eq!(result.relevance, 0.5);
        assert!(!result.is_focus_update);
        assert!(result.new_topic.is_none());
        assert!(result.new_question.is_none());
    }

    #[test]
    fn test_focus_analysis_result_clamp_relevance() {
        // Clamping happens in `parse_analysis_result`, not in serde, so go through it.
        let tracker = tracker();

        let too_high = r#"{"relevance": 1.5, "is_focus_update": false, "reason": "test"}"#;
        let result = tracker.parse_analysis_result(too_high).unwrap();
        assert_eq!(result.relevance, 1.0);

        let too_low = r#"{"relevance": -0.3, "is_focus_update": false, "reason": "test"}"#;
        let result = tracker.parse_analysis_result(too_low).unwrap();
        assert_eq!(result.relevance, 0.0);
    }

    #[test]
    fn test_ai_focus_tracker_creation() {
        let tracker = tracker();
        assert!(tracker.current_focus().is_none());
        assert!(tracker.config().validate());
    }

    #[test]
    fn test_format_current_focus_none() {
        let text = tracker().format_current_focus();
        assert!(text.contains("尚未建立明确焦点"));
    }

    #[test]
    fn test_format_current_focus_some() {
        let mut tracker = tracker();
        tracker.set_focus(ConversationFocus {
            current_topic: Some("API设计".to_string()),
            current_question: Some("如何优化性能?".to_string()),
            recent_context: vec!["之前讨论了数据库".to_string()],
            topic_transitions: Vec::new(),
            detected_at: 5,
        });

        let text = tracker.format_current_focus();
        assert!(text.contains("API设计"));
        assert!(text.contains("如何优化性能"));
        assert!(text.contains("之前讨论了数据库"));
    }

    #[test]
    fn test_format_message() {
        let text = tracker().format_message(&user_message("如何优化API性能?"));
        assert!(text.contains("用户"));
        assert!(text.contains("如何优化API性能"));
    }

    #[test]
    fn test_format_message_truncation() {
        let long_text = "x".repeat(600);
        let text = tracker().format_message(&user_message(&long_text));
        assert!(text.contains("已截断"));
        assert!(text.len() < long_text.len() + 50);
    }

    #[test]
    fn test_message_cache_key() {
        let key = tracker().message_cache_key(&user_message("测试消息内容"));
        assert!(key.starts_with("User:"));
    }

    #[test]
    fn test_cache_result() {
        let mut tracker = tracker();
        let key = "test-key".to_string();
        let result = FocusAnalysisResult {
            relevance: 0.8,
            is_focus_update: false,
            new_topic: None,
            new_question: None,
            context_to_add: None,
            reason: "test".to_string(),
        };

        tracker.cache_result(key.clone(), result);

        assert_eq!(tracker.analysis_cache.len(), 1);
        assert_eq!(tracker.analysis_cache[0].0, key);
        assert_eq!(tracker.analysis_cache[0].1.relevance, 0.8);
    }

    #[test]
    fn test_cache_result_replaces_existing_key() {
        let mut tracker = tracker();
        tracker.cache_result(
            "dup".to_string(),
            FocusAnalysisResult { relevance: 0.2, ..FocusAnalysisResult::default() },
        );
        tracker.cache_result(
            "dup".to_string(),
            FocusAnalysisResult { relevance: 0.7, ..FocusAnalysisResult::default() },
        );

        assert_eq!(tracker.analysis_cache.len(), 1);
        assert_eq!(tracker.analysis_cache[0].1.relevance, 0.7);
    }

    #[test]
    fn test_cache_result_max_size() {
        let mut tracker = tracker();
        tracker.max_cache_size = 3;

        for i in 0..5 {
            tracker.cache_result(format!("key-{}", i), FocusAnalysisResult::default());
        }

        assert_eq!(tracker.analysis_cache.len(), 3);
        // The two oldest entries were evicted.
        assert!(!tracker.analysis_cache.iter().any(|(k, _)| k == "key-0"));
        assert!(!tracker.analysis_cache.iter().any(|(k, _)| k == "key-1"));
    }

    #[test]
    fn test_set_and_clear_focus() {
        let mut tracker = tracker();

        tracker.set_focus(ConversationFocus {
            current_topic: Some("测试话题".to_string()),
            current_question: None,
            recent_context: Vec::new(),
            topic_transitions: Vec::new(),
            detected_at: 0,
        });
        tracker.cache_result("k".to_string(), FocusAnalysisResult::default());
        assert!(tracker.current_focus().is_some());

        tracker.clear_focus();
        assert!(tracker.current_focus().is_none());
        assert!(tracker.analysis_cache.is_empty());
    }

    #[test]
    fn test_detect_focus_fallback() {
        let messages = vec![user_message("如何优化 Rust 性能?")];
        let focus = tracker().detect_focus_fallback(&messages);
        assert!(focus.current_question.is_some());
    }

    #[test]
    fn test_focus_score_without_focus() {
        let score = tracker().focus_score(&user_message("测试消息"));
        assert_eq!(score, 0.5);
    }

    #[test]
    fn test_focus_score_uses_cached_relevance() {
        let mut tracker = tracker();
        let message = user_message("缓存的消息");
        let key = tracker.message_cache_key(&message);
        tracker.cache_result(
            key,
            FocusAnalysisResult { relevance: 0.9, ..FocusAnalysisResult::default() },
        );

        assert_eq!(tracker.focus_score(&message), 0.9);
    }

    #[test]
    fn test_create_focus_message_without_focus() {
        let msg = tracker().create_focus_message();
        assert!(matches!(msg.role, Role::System));
        assert!(message_text(&msg).contains("初始化"));
    }

    #[test]
    fn test_create_focus_message_with_focus() {
        let mut tracker = tracker();
        tracker.set_focus(ConversationFocus {
            current_topic: Some("API优化".to_string()),
            current_question: Some("如何提升性能?".to_string()),
            recent_context: Vec::new(),
            topic_transitions: Vec::new(),
            detected_at: 5,
        });

        let msg = tracker.create_focus_message();
        assert!(matches!(msg.role, Role::System));
        let text = message_text(&msg);
        assert!(text.contains("API优化"));
        assert!(text.contains("如何提升性能"));
    }

    #[test]
    fn test_parse_analysis_result_valid() {
        let json = r#"{
            "relevance": 0.8,
            "is_focus_update": false,
            "reason": "高度相关"
        }"#;

        let result = tracker().parse_analysis_result(json).unwrap();
        assert_eq!(result.relevance, 0.8);
        assert!(!result.is_focus_update);
        assert_eq!(result.reason, "高度相关");
    }

    #[test]
    fn test_parse_analysis_result_with_code_block() {
        let json = r#"```json
{
    "relevance": 0.7,
    "is_focus_update": true,
    "new_topic": "新话题",
    "reason": "话题切换"
}
```"#;

        let result = tracker().parse_analysis_result(json).unwrap();
        assert_eq!(result.relevance, 0.7);
        assert!(result.is_focus_update);
        assert_eq!(result.new_topic, Some("新话题".to_string()));
    }
}
828}