Skip to main content

matrixcode_core/compress/
complexity.rs

//! Conversation Complexity Analyzer
//!
//! Analyzes conversation complexity to adapt compression thresholds dynamically.
4
5use crate::providers::{Message, MessageContent};
6
/// How technically dense a conversation is.
///
/// Variants are declared from the most to the least complex level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityLevel {
    /// Technical discussion involving code, errors, or tools.
    High,
    /// A blend of technical and general conversation.
    Medium,
    /// Casual conversation or simple questions.
    Low,
}
17
/// Weights and thresholds used to score a conversation.
///
/// Each weight is the score contributed by one occurrence of a signal. The
/// summed score is divided by the number of messages before it is compared
/// with the thresholds, so both thresholds are expressed as an *average score
/// per message*.
#[derive(Debug, Clone)]
pub struct ComplexityConfig {
    /// Score added for every message that contains code.
    code_weight: f32,
    /// Score added for every message that uses or mentions a tool.
    tool_weight: f32,
    /// Score added for every technical keyword found in a message.
    keyword_weight: f32,
    /// Score added for every message that mentions an error.
    error_weight: f32,
    /// Minimum average per-message score classified as high complexity.
    high_threshold: f32,
    /// Minimum average per-message score classified as medium complexity.
    medium_threshold: f32,
}

impl Default for ComplexityConfig {
    /// Returns the built-in weights and thresholds.
    fn default() -> Self {
        Self {
            code_weight: 0.3,
            tool_weight: 0.25,
            keyword_weight: 0.15,
            error_weight: 0.2,
            // Both thresholds are compared with the per-message average.
            // The former defaults (5.0 / 2.0) were out of reach on that scale:
            // with the weights above and the 28 built-in keywords a message
            // can score at most 0.3 + 0.25 + 0.2 + 28 * 0.15 = 4.95, so `High`
            // was never returned, and `Medium` needed about nine distinct
            // keywords in every message.
            high_threshold: 0.5,
            medium_threshold: 0.05,
        }
    }
}
47
48/// Analyzes conversation complexity
49pub struct ComplexityAnalyzer {
50    config: ComplexityConfig,
51    /// Technical keywords (中英文)
52    tech_keywords: Vec<String>,
53}
54
55impl Default for ComplexityAnalyzer {
56    fn default() -> Self {
57        Self::new(ComplexityConfig::default())
58    }
59}
60
61impl ComplexityAnalyzer {
62    pub fn new(config: ComplexityConfig) -> Self {
63        Self {
64            config,
65            tech_keywords: vec![
66                // 中文关键词
67                "函数".to_string(),
68                "优化".to_string(),
69                "性能".to_string(),
70                "错误".to_string(),
71                "测试".to_string(),
72                "架构".to_string(),
73                "数据库".to_string(),
74                "算法".to_string(),
75                "重构".to_string(),
76                "调试".to_string(),
77                "部署".to_string(),
78                "缓存".to_string(),
79                "并发".to_string(),
80                "异步".to_string(),
81                // 英文关键词
82                "function".to_string(),
83                "optimize".to_string(),
84                "performance".to_string(),
85                "error".to_string(),
86                "test".to_string(),
87                "architecture".to_string(),
88                "database".to_string(),
89                "algorithm".to_string(),
90                "refactor".to_string(),
91                "debug".to_string(),
92                "deploy".to_string(),
93                "cache".to_string(),
94                "async".to_string(),
95                "concurrent".to_string(),
96            ],
97        }
98    }
99
100    /// Analyze conversation complexity
101    pub fn analyze(messages: &[Message]) -> ComplexityLevel {
102        let analyzer = Self::default();
103        analyzer.analyze_complexity(messages)
104    }
105
106    /// Calculate complexity score
107    pub fn analyze_complexity(&self, messages: &[Message]) -> ComplexityLevel {
108        if messages.is_empty() {
109            return ComplexityLevel::Low;
110        }
111
112        let mut score = 0.0;
113
114        // 1. Code density detection
115        let code_count = messages.iter()
116            .filter(|m| self.has_code(m))
117            .count();
118        score += code_count as f32 * self.config.code_weight;
119
120        // 2. Tool usage frequency
121        let tool_count = messages.iter()
122            .filter(|m| self.has_tool_use(m))
123            .count();
124        score += tool_count as f32 * self.config.tool_weight;
125
126        // 3. Technical keyword density
127        let keyword_hits = messages.iter()
128            .map(|m| self.count_keywords(m))
129            .sum::<usize>();
130        score += keyword_hits as f32 * self.config.keyword_weight;
131
132        // 4. Error mentions
133        let error_count = messages.iter()
134            .filter(|m| self.has_error(m))
135            .count();
136        score += error_count as f32 * self.config.error_weight;
137
138        // Normalize by message count
139        score /= messages.len() as f32;
140
141        // Determine level
142        if score >= self.config.high_threshold {
143            ComplexityLevel::High
144        } else if score >= self.config.medium_threshold {
145            ComplexityLevel::Medium
146        } else {
147            ComplexityLevel::Low
148        }
149    }
150
151    /// Check if message contains code
152    fn has_code(&self, message: &Message) -> bool {
153        let content = self.get_text_content(message);
154        content.contains("```") || 
155            content.contains("fn ") || 
156            content.contains("function ") ||
157            content.contains("class ") ||
158            content.contains("struct ")
159    }
160
161    /// Check if message has tool use
162    fn has_tool_use(&self, message: &Message) -> bool {
163        matches!(message.content, MessageContent::Blocks(_)) ||
164            self.get_text_content(message).contains("tool") ||
165            self.get_text_content(message).contains("工具")
166    }
167
168    /// Check if message mentions error
169    fn has_error(&self, message: &Message) -> bool {
170        let content = self.get_text_content(message);
171        content.contains("error") ||
172            content.contains("failed") ||
173            content.contains("错误") ||
174            content.contains("失败") ||
175            content.contains("异常") ||
176            content.contains("exception")
177    }
178
179    /// Count technical keywords in message
180    fn count_keywords(&self, message: &Message) -> usize {
181        let content = self.get_text_content(message).to_lowercase();
182        self.tech_keywords.iter()
183            .filter(|kw| content.contains(&kw.to_lowercase()))
184            .count()
185    }
186
187    /// Extract text content from message
188    fn get_text_content(&self, message: &Message) -> String {
189        match &message.content {
190            MessageContent::Text(text) => text.clone(),
191            MessageContent::Blocks(blocks) => {
192                blocks.iter()
193                    .filter_map(|block| {
194                        if let crate::providers::ContentBlock::Text { text } = block {
195                            Some(text.clone())
196                        } else {
197                            None
198                        }
199                    })
200                    .collect::<Vec<_>>()
201                    .join("\n")
202            }
203        }
204    }
205
206    /// Get complexity description
207    pub fn complexity_description(level: ComplexityLevel) -> &'static str {
208        match level {
209            ComplexityLevel::High => "技术讨论密集:大量代码、工具使用、错误处理",
210            ComplexityLevel::Medium => "混合对话:部分技术内容",
211            ComplexityLevel::Low => "简单对话:少量技术内容",
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Role;

    /// Builds a plain-text message sent by `role`.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// An empty conversation is classified as low complexity.
    #[test]
    fn test_empty_messages() {
        assert_eq!(ComplexityAnalyzer::analyze(&[]), ComplexityLevel::Low);
    }

    /// Code, an error report and nine keyword hits spread over three messages.
    ///
    /// With the default weights the average score per message is
    /// (0.3 + 9 * 0.15 + 0.2) / 3 ≈ 0.62.
    #[test]
    fn test_high_complexity() {
        let messages = [
            text_message(Role::User, "这个函数性能有问题,需要优化算法"),
            text_message(
                Role::Assistant,
                "好的,我来优化这个函数:\n```rust\nfn optimize() {}\n```",
            ),
            text_message(Role::User, "测试失败了,出现错误"),
        ];

        assert_eq!(
            ComplexityAnalyzer::analyze(&messages),
            ComplexityLevel::High
        );
    }

    /// A single keyword hit in a two-message exchange.
    ///
    /// With the default weights the average score per message is
    /// 0.15 / 2 = 0.075.
    #[test]
    fn test_medium_complexity() {
        let messages = [
            text_message(Role::User, "如何在数据库中查询数据?"),
            text_message(Role::Assistant, "你可以使用 SQL 查询"),
        ];

        assert_eq!(
            ComplexityAnalyzer::analyze(&messages),
            ComplexityLevel::Medium
        );
    }

    /// A greeting exchange without any technical signal.
    #[test]
    fn test_low_complexity() {
        let messages = [
            text_message(Role::User, "你好"),
            text_message(Role::Assistant, "你好!有什么可以帮助你的?"),
        ];

        assert_eq!(ComplexityAnalyzer::analyze(&messages), ComplexityLevel::Low);
    }
}
281}