// matrixcode_core/compress/focus_extractor.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use chrono::Utc;

use crate::memory::MemoryEntry;
use crate::providers::{Message, MessageContent, Role};
use super::focus_point::{FocusPoint, FocusStatus};
use super::prompts_zh::{CLASSIFICATION_PROMPT, EXTRACTION_PROMPT};
/// Stateless namespace for the focus-point helpers: deriving a [`FocusPoint`]
/// from a memory entry, rendering the extraction/classification prompts, and
/// parsing the JSON the model sends back.
pub struct FocusExtractor;

17impl FocusExtractor {
18 pub fn extract_from_memory(memory: &MemoryEntry) -> Option<FocusPoint> {
22 if memory.tags.is_empty() {
24 return None;
25 }
26
27 let topic = memory.content
29 .split('\n')
30 .next()
31 .unwrap_or(&memory.content)
32 .to_string();
33
34 Some(FocusPoint::new(
35 format!("focus-{}", memory.id),
36 topic,
37 memory.tags.clone(),
38 vec![],
39 None,
40 0,
41 ).with_importance((memory.importance / 100.0) as f32))
42 }
43
44 pub fn create_extraction_prompt(messages: &[Message]) -> String {
48 let conversation = Self::format_conversation(messages);
49 EXTRACTION_PROMPT.replace("{conversation}", &conversation)
50 }
51
52 pub fn create_classification_prompt(user_input: &str, existing_foci: &[FocusPoint]) -> String {
54 let foci_description = Self::format_existing_foci(existing_foci);
55 CLASSIFICATION_PROMPT
56 .replace("{user_input}", user_input)
57 .replace("{foci_description}", &foci_description)
58 }
59
60 fn format_conversation(messages: &[Message]) -> String {
62 messages.iter()
63 .map(|msg| {
64 let role = match msg.role {
65 Role::User => "User",
66 Role::Assistant => "AI",
67 Role::System => "System",
68 Role::Tool => "Tool",
69 };
70
71 let content = match &msg.content {
72 MessageContent::Text(text) => text.clone(),
73 MessageContent::Blocks(blocks) => {
74 blocks.iter()
75 .filter_map(|b| {
76 match b {
77 crate::providers::ContentBlock::Text { text } => Some(text.clone()),
78 _ => None,
79 }
80 })
81 .collect::<Vec<_>>()
82 .join("\n")
83 }
84 };
85
86 format!("{}: {}", role, content)
87 })
88 .collect::<Vec<_>>()
89 .join("\n\n")
90 }
91
92 fn format_existing_foci(foci: &[FocusPoint]) -> String {
94 foci.iter()
95 .map(|f| {
96 format!(
97 "- ID: {}\n Topic: {}\n Keywords: {}\n Entities: {}\n Status: {}\n Importance: {}",
98 f.id,
99 f.topic,
100 f.keywords.join(", "),
101 f.entities.join(", "),
102 f.status,
103 f.importance
104 )
105 })
106 .collect::<Vec<_>>()
107 .join("\n\n")
108 }
109
110 pub fn parse_focus_response(response: &str) -> Result<Vec<FocusPoint>, String> {
112 let json_str = Self::extract_json(response)?;
114
115 let parsed: serde_json::Value = serde_json::from_str(&json_str)
116 .map_err(|e| format!("JSON parse error: {}", e))?;
117
118 let focuses = parsed["focuses"]
119 .as_array()
120 .ok_or("No focuses array in response")?;
121
122 let mut result = Vec::new();
123
124 for focus_json in focuses {
125 let importance = focus_json["importance"]
126 .as_f64()
127 .unwrap_or(0.7) as f32;
128
129 let focus = FocusPoint::new(
130 format!("focus-{}", Utc::now().timestamp()),
131 focus_json["topic"]
132 .as_str()
133 .ok_or("Missing topic")?
134 .to_string(),
135 focus_json["keywords"]
136 .as_array()
137 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
138 .unwrap_or_default(),
139 focus_json["entities"]
140 .as_array()
141 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
142 .unwrap_or_default(),
143 focus_json["core_question"]
144 .as_str()
145 .map(String::from),
146 0,
147 ).with_importance(importance);
148
149 if focus_json["is_current"].as_bool().unwrap_or(false) {
151 result.push(focus);
152 } else {
153 let mut f = focus;
154 f.status = FocusStatus::Suspended;
155 result.push(f);
156 }
157 }
158
159 Ok(result)
160 }
161
162 pub fn parse_classification_response(response: &str) -> Result<ClassificationResult, String> {
164 let json_str = Self::extract_json(response)?;
165
166 let parsed: serde_json::Value = serde_json::from_str(&json_str)
167 .map_err(|e| format!("JSON parse error: {}", e))?;
168
169 let classification = &parsed["classification"];
170
171 let matched_focus_id = classification["matched_focus_id"]
172 .as_str()
173 .map(String::from);
174
175 let relevance_scores = classification["relevance_scores"]
176 .as_object()
177 .map(|obj| {
178 obj.iter()
179 .filter_map(|(k, v)| {
180 v.as_f64().map(|score| (k.clone(), score as f32))
181 })
182 .collect()
183 })
184 .unwrap_or_default();
185
186 let is_new_focus = classification["is_new_focus"]
187 .as_bool()
188 .unwrap_or(false);
189
190 let new_focus = if is_new_focus {
191 let new_focus_json = &parsed["new_focus"];
192
193 Some(FocusPoint::new(
194 format!("focus-{}", Utc::now().timestamp()),
195 new_focus_json["topic"]
196 .as_str()
197 .ok_or("Missing new focus topic")?
198 .to_string(),
199 new_focus_json["keywords"]
200 .as_array()
201 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
202 .unwrap_or_default(),
203 new_focus_json["entities"]
204 .as_array()
205 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
206 .unwrap_or_default(),
207 new_focus_json["core_question"]
208 .as_str()
209 .map(String::from),
210 0,
211 ))
212 } else {
213 None
214 };
215
216 Ok(ClassificationResult {
217 matched_focus_id,
218 relevance_scores,
219 is_new_focus,
220 new_focus,
221 })
222 }
223
224 fn extract_json(response: &str) -> Result<String, String> {
226 let start = response.find('{')
228 .ok_or("No JSON found in response")?;
229
230 let mut end = start;
231 let mut depth = 0;
232
233 for (idx, ch) in response[start..].chars().enumerate() {
234 if ch == '{' {
235 depth += 1;
236 } else if ch == '}' {
237 depth -= 1;
238 if depth == 0 {
239 end = start + idx + 1;
240 break;
241 }
242 }
243 }
244
245 Ok(response[start..end].to_string())
246 }
247}
248
249#[derive(Debug, Clone)]
251pub struct ClassificationResult {
252 pub matched_focus_id: Option<String>,
254
255 pub relevance_scores: HashMap<String, f32>,
257
258 pub is_new_focus: bool,
260
261 pub new_focus: Option<FocusPoint>,
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message whose content is a single plain-text payload.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    #[test]
    fn test_create_extraction_prompt() {
        let conversation = vec![
            text_message(Role::User, "How to optimize Rust performance?"),
            text_message(Role::Assistant, "Use profiling tools."),
        ];

        let rendered = FocusExtractor::create_extraction_prompt(&conversation);

        for needle in ["分析对话内容并提取聚焦点", "optimize Rust performance", "\"focuses\":"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_parse_focus_response() {
        let raw = r#"Based on the conversation, here are the focus points:
{
  "focuses": [
    {
      "topic": "Optimizing Rust performance",
      "keywords": ["performance", "rust", "optimization"],
      "entities": ["main.rs", "benchmark"],
      "core_question": "How to improve performance?",
      "importance": 0.85,
      "is_current": true
    }
  ]
}
"#;

        let outcome = FocusExtractor::parse_focus_response(raw);
        assert!(outcome.is_ok());

        let parsed = outcome.unwrap();
        assert_eq!(parsed.len(), 1);

        let only = &parsed[0];
        assert_eq!(only.topic, "Optimizing Rust performance");
        assert_eq!(only.keywords.len(), 3);
    }

    #[test]
    fn test_create_classification_prompt() {
        let known = vec![FocusPoint::new(
            "focus-1".to_string(),
            "Database optimization".to_string(),
            vec!["database".to_string()],
            vec!["db.rs".to_string()],
            Some("Why is query slow?".to_string()),
            0,
        )];

        let rendered =
            FocusExtractor::create_classification_prompt("The database query is still slow", &known);

        for needle in ["判断用户输入属于哪个聚焦点", "Database optimization", "\"relevance_scores\":"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_parse_classification_response() {
        let raw = r#"Classification result:
{
  "classification": {
    "matched_focus_id": "focus-1",
    "relevance_scores": {
      "focus-1": 0.85
    },
    "is_new_focus": false,
    "reason": "Input mentions database"
  }
}
"#;

        let outcome = FocusExtractor::parse_classification_response(raw);
        assert!(outcome.is_ok());

        let verdict = outcome.unwrap();
        assert_eq!(verdict.matched_focus_id, Some("focus-1".to_string()));
        assert!(!verdict.is_new_focus);
    }
}