Skip to main content

matrixcode_core/compress/
focus_extractor.rs

1use crate::providers::{Message, MessageContent, Role};
2use crate::memory::MemoryEntry;
3use super::focus_point::{FocusPoint, FocusStatus};
4use super::prompts_zh::{EXTRACTION_PROMPT, CLASSIFICATION_PROMPT};
5use chrono::Utc;
6use std::collections::HashMap;
7
/// AI-assisted focus-point extractor.
///
/// A stateless namespace for the focus-point workflow. It is responsible for:
/// 1. extracting focus points from conversation history,
/// 2. deciding which focus point a user input belongs to,
/// 3. creating new focus points,
/// 4. updating focus-point status.
#[derive(Debug, Clone, Copy, Default)]
pub struct FocusExtractor;
16
17impl FocusExtractor {
18    /// 从记忆条目中提取聚焦点
19    /// 
20    /// AI 在提取记忆时,同时提取聚焦点信息
21    pub fn extract_from_memory(memory: &MemoryEntry) -> Option<FocusPoint> {
22        // 从记忆的 tags 构建 keywords
23        if memory.tags.is_empty() {
24            return None;
25        }
26        
27        // 从 content 提取主题(简单实现:取第一句)
28        let topic = memory.content
29            .split('\n')
30            .next()
31            .unwrap_or(&memory.content)
32            .to_string();
33        
34        Some(FocusPoint::new(
35            format!("focus-{}", memory.id),
36            topic,
37            memory.tags.clone(),
38            vec![],
39            None,
40            0,
41        ).with_importance((memory.importance / 100.0) as f32))
42    }
43    
44    /// 从对话历史提取聚焦点(AI 分析)
45    /// 
46    /// 返回 JSON 格式的聚焦点信息,供 AI 模型解析
47    pub fn create_extraction_prompt(messages: &[Message]) -> String {
48        let conversation = Self::format_conversation(messages);
49        EXTRACTION_PROMPT.replace("{conversation}", &conversation)
50    }
51    
52    /// 从用户输入判断聚焦点归属
53    pub fn create_classification_prompt(user_input: &str, existing_foci: &[FocusPoint]) -> String {
54        let foci_description = Self::format_existing_foci(existing_foci);
55        CLASSIFICATION_PROMPT
56            .replace("{user_input}", user_input)
57            .replace("{foci_description}", &foci_description)
58    }
59    
60    /// 格式化对话历史
61    fn format_conversation(messages: &[Message]) -> String {
62        messages.iter()
63            .map(|msg| {
64                let role = match msg.role {
65                    Role::User => "User",
66                    Role::Assistant => "AI",
67                    Role::System => "System",
68                    Role::Tool => "Tool",
69                };
70                
71                let content = match &msg.content {
72                    MessageContent::Text(text) => text.clone(),
73                    MessageContent::Blocks(blocks) => {
74                        blocks.iter()
75                            .filter_map(|b| {
76                                match b {
77                                    crate::providers::ContentBlock::Text { text } => Some(text.clone()),
78                                    _ => None,
79                                }
80                            })
81                            .collect::<Vec<_>>()
82                            .join("\n")
83                    }
84                };
85                
86                format!("{}: {}", role, content)
87            })
88            .collect::<Vec<_>>()
89            .join("\n\n")
90    }
91    
92    /// 格式化现有聚焦点
93    fn format_existing_foci(foci: &[FocusPoint]) -> String {
94        foci.iter()
95            .map(|f| {
96                format!(
97                    "- ID: {}\n  Topic: {}\n  Keywords: {}\n  Entities: {}\n  Status: {}\n  Importance: {}",
98                    f.id,
99                    f.topic,
100                    f.keywords.join(", "),
101                    f.entities.join(", "),
102                    f.status,
103                    f.importance
104                )
105            })
106            .collect::<Vec<_>>()
107            .join("\n\n")
108    }
109    
110    /// 解析 AI 返回的聚焦点 JSON
111    pub fn parse_focus_response(response: &str) -> Result<Vec<FocusPoint>, String> {
112        // 尝试提取 JSON 部分
113        let json_str = Self::extract_json(response)?;
114        
115        let parsed: serde_json::Value = serde_json::from_str(&json_str)
116            .map_err(|e| format!("JSON parse error: {}", e))?;
117        
118        let focuses = parsed["focuses"]
119            .as_array()
120            .ok_or("No focuses array in response")?;
121        
122        let mut result = Vec::new();
123        
124        for focus_json in focuses {
125            let importance = focus_json["importance"]
126                .as_f64()
127                .unwrap_or(0.7) as f32;
128            
129            let focus = FocusPoint::new(
130                format!("focus-{}", Utc::now().timestamp()),
131                focus_json["topic"]
132                    .as_str()
133                    .ok_or("Missing topic")?
134                    .to_string(),
135                focus_json["keywords"]
136                    .as_array()
137                    .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
138                    .unwrap_or_default(),
139                focus_json["entities"]
140                    .as_array()
141                    .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
142                    .unwrap_or_default(),
143                focus_json["core_question"]
144                    .as_str()
145                    .map(String::from),
146                0,
147            ).with_importance(importance);
148            
149            // Set status
150            if focus_json["is_current"].as_bool().unwrap_or(false) {
151                result.push(focus);
152            } else {
153                let mut f = focus;
154                f.status = FocusStatus::Suspended;
155                result.push(f);
156            }
157        }
158        
159        Ok(result)
160    }
161    
162    /// 解析分类响应
163    pub fn parse_classification_response(response: &str) -> Result<ClassificationResult, String> {
164        let json_str = Self::extract_json(response)?;
165        
166        let parsed: serde_json::Value = serde_json::from_str(&json_str)
167            .map_err(|e| format!("JSON parse error: {}", e))?;
168        
169        let classification = &parsed["classification"];
170        
171        let matched_focus_id = classification["matched_focus_id"]
172            .as_str()
173            .map(String::from);
174        
175        let relevance_scores = classification["relevance_scores"]
176            .as_object()
177            .map(|obj| {
178                obj.iter()
179                    .filter_map(|(k, v)| {
180                        v.as_f64().map(|score| (k.clone(), score as f32))
181                    })
182                    .collect()
183            })
184            .unwrap_or_default();
185        
186        let is_new_focus = classification["is_new_focus"]
187            .as_bool()
188            .unwrap_or(false);
189        
190        let new_focus = if is_new_focus {
191            let new_focus_json = &parsed["new_focus"];
192            
193            Some(FocusPoint::new(
194                format!("focus-{}", Utc::now().timestamp()),
195                new_focus_json["topic"]
196                    .as_str()
197                    .ok_or("Missing new focus topic")?
198                    .to_string(),
199                new_focus_json["keywords"]
200                    .as_array()
201                    .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
202                    .unwrap_or_default(),
203                new_focus_json["entities"]
204                    .as_array()
205                    .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
206                    .unwrap_or_default(),
207                new_focus_json["core_question"]
208                    .as_str()
209                    .map(String::from),
210                0,
211            ))
212        } else {
213            None
214        };
215        
216        Ok(ClassificationResult {
217            matched_focus_id,
218            relevance_scores,
219            is_new_focus,
220            new_focus,
221        })
222    }
223    
224    /// 从响应中提取 JSON
225    fn extract_json(response: &str) -> Result<String, String> {
226        // 尝试找到 JSON 块
227        let start = response.find('{')
228            .ok_or("No JSON found in response")?;
229        
230        let mut end = start;
231        let mut depth = 0;
232        
233        for (idx, ch) in response[start..].chars().enumerate() {
234            if ch == '{' {
235                depth += 1;
236            } else if ch == '}' {
237                depth -= 1;
238                if depth == 0 {
239                    end = start + idx + 1;
240                    break;
241                }
242            }
243        }
244        
245        Ok(response[start..end].to_string())
246    }
247}
248
249/// 分类结果
250#[derive(Debug, Clone)]
251pub struct ClassificationResult {
252    /// 匹配的聚焦点 ID
253    pub matched_focus_id: Option<String>,
254    
255    /// 各聚焦点的相关性分数
256    pub relevance_scores: HashMap<String, f32>,
257    
258    /// 是否是新聚焦点
259    pub is_new_focus: bool,
260    
261    /// 新聚焦点(如果是)
262    pub new_focus: Option<FocusPoint>,
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_extraction_prompt() {
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("How to optimize Rust performance?".to_string()),
            },
            Message {
                role: Role::Assistant,
                content: MessageContent::Text("Use profiling tools.".to_string()),
            },
        ];

        let prompt = FocusExtractor::create_extraction_prompt(&messages);

        assert!(prompt.contains("分析对话内容并提取聚焦点"));
        assert!(prompt.contains("optimize Rust performance"));
        assert!(prompt.contains("\"focuses\":"));
    }

    #[test]
    fn test_parse_focus_response() {
        let response = r#"Based on the conversation, here are the focus points:
{
  "focuses": [
    {
      "topic": "Optimizing Rust performance",
      "keywords": ["performance", "rust", "optimization"],
      "entities": ["main.rs", "benchmark"],
      "core_question": "How to improve performance?",
      "importance": 0.85,
      "is_current": true
    }
  ]
}
"#;

        // `expect` (rather than `assert!(is_ok())`) prints the parser's error
        // message if this ever regresses.
        let focuses = FocusExtractor::parse_focus_response(response)
            .expect("focus response should parse");

        assert_eq!(focuses.len(), 1);
        assert_eq!(focuses[0].topic, "Optimizing Rust performance");
        assert_eq!(focuses[0].keywords.len(), 3);
    }

    #[test]
    fn test_parse_focus_response_suspends_non_current_focus() {
        let response = r#"{"focuses": [{"topic": "Old topic", "is_current": false}]}"#;

        let focuses = FocusExtractor::parse_focus_response(response)
            .expect("focus response should parse");

        assert_eq!(focuses.len(), 1);
        assert!(matches!(focuses[0].status, FocusStatus::Suspended));
    }

    #[test]
    fn test_parse_focus_response_rejects_missing_topic() {
        let response = r#"{"focuses": [{"keywords": ["orphan"]}]}"#;
        assert!(FocusExtractor::parse_focus_response(response).is_err());
    }

    #[test]
    fn test_parse_focus_response_rejects_missing_focuses_array() {
        assert!(FocusExtractor::parse_focus_response("{}").is_err());
    }

    #[test]
    fn test_create_classification_prompt() {
        let existing_foci = vec![
            FocusPoint::new(
                "focus-1".to_string(),
                "Database optimization".to_string(),
                vec!["database".to_string()],
                vec!["db.rs".to_string()],
                Some("Why is query slow?".to_string()),
                0,
            ),
        ];

        let prompt = FocusExtractor::create_classification_prompt(
            "The database query is still slow",
            &existing_foci
        );

        assert!(prompt.contains("判断用户输入属于哪个聚焦点"));
        assert!(prompt.contains("Database optimization"));
        assert!(prompt.contains("\"relevance_scores\":"));
    }

    #[test]
    fn test_parse_classification_response() {
        let response = r#"Classification result:
{
  "classification": {
    "matched_focus_id": "focus-1",
    "relevance_scores": {
      "focus-1": 0.85
    },
    "is_new_focus": false,
    "reason": "Input mentions database"
  }
}
"#;

        let classification = FocusExtractor::parse_classification_response(response)
            .expect("classification response should parse");

        assert_eq!(classification.matched_focus_id, Some("focus-1".to_string()));
        assert!(!classification.is_new_focus);
        assert!(classification.new_focus.is_none());
    }

    #[test]
    fn test_parse_classification_response_with_new_focus() {
        let response = r#"{
  "classification": {
    "matched_focus_id": null,
    "relevance_scores": {"focus-1": 0.5},
    "is_new_focus": true
  },
  "new_focus": {
    "topic": "Deployment pipeline",
    "keywords": ["ci", "deploy"],
    "entities": [],
    "core_question": "Why does the deploy fail?"
  }
}"#;

        let classification = FocusExtractor::parse_classification_response(response)
            .expect("classification response should parse");

        assert!(classification.is_new_focus);
        assert!(classification.matched_focus_id.is_none());
        assert_eq!(classification.relevance_scores.get("focus-1"), Some(&0.5));

        let new_focus = classification
            .new_focus
            .expect("new focus should be present when is_new_focus is true");
        assert_eq!(new_focus.topic, "Deployment pipeline");
        assert_eq!(new_focus.keywords, vec!["ci", "deploy"]);
    }
}