Skip to main content

matrixcode_core/compress/
complexity.rs

1//! Conversation Complexity Analyzer
2//! 
3//! Analyzes conversation complexity to adapt compression thresholds dynamically.
4
5use crate::providers::{Message, MessageContent};
6
/// How technically dense a conversation is.
///
/// Produced by [`ComplexityAnalyzer::analyze_complexity`] by comparing the
/// per-message average score against the configured thresholds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityLevel {
    /// Technical discussion dominated by code, errors, and tool usage.
    High,
    /// A blend of technical and general conversation.
    Medium,
    /// Casual conversation or simple questions.
    Low,
}
17
/// Weights and thresholds used by [`ComplexityAnalyzer`] to score a conversation.
///
/// Each weight is multiplied by a count gathered over the whole conversation; the
/// weighted sum is then divided by the number of messages, so both thresholds are
/// compared against a *per-message average* score.
///
/// Fields are public so callers outside this module can build a custom
/// configuration, typically via struct-update syntax:
/// `ComplexityConfig { high_threshold: 0.8, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct ComplexityConfig {
    /// Weight applied to the number of messages that contain code.
    pub code_weight: f32,
    /// Weight applied to the number of messages that involve tool usage.
    pub tool_weight: f32,
    /// Weight applied to the total number of technical-keyword hits.
    pub keyword_weight: f32,
    /// Weight applied to the number of messages that mention an error.
    pub error_weight: f32,
    /// Average score at or above which a conversation is [`ComplexityLevel::High`].
    pub high_threshold: f32,
    /// Average score at or above which a conversation is [`ComplexityLevel::Medium`]
    /// (when it is below `high_threshold`).
    pub medium_threshold: f32,
}
34
35impl Default for ComplexityConfig {
36    fn default() -> Self {
37        Self {
38            code_weight: 0.3,
39            tool_weight: 0.25,
40            keyword_weight: 0.15,
41            error_weight: 0.2,
42            high_threshold: 5.0,
43            medium_threshold: 2.0,
44        }
45    }
46}
47
48/// Analyzes conversation complexity
49pub struct ComplexityAnalyzer {
50    config: ComplexityConfig,
51    /// Technical keywords (中英文)
52    tech_keywords: Vec<String>,
53}
54
55impl Default for ComplexityAnalyzer {
56    fn default() -> Self {
57        Self::new(ComplexityConfig::default())
58    }
59}
60
61impl ComplexityAnalyzer {
62    pub fn new(config: ComplexityConfig) -> Self {
63        Self {
64            config,
65            tech_keywords: vec![
66                // 中文关键词
67                "函数".to_string(),
68                "优化".to_string(),
69                "性能".to_string(),
70                "错误".to_string(),
71                "测试".to_string(),
72                "架构".to_string(),
73                "数据库".to_string(),
74                "算法".to_string(),
75                "重构".to_string(),
76                "调试".to_string(),
77                "部署".to_string(),
78                "缓存".to_string(),
79                "并发".to_string(),
80                "异步".to_string(),
81                // 英文关键词
82                "function".to_string(),
83                "optimize".to_string(),
84                "performance".to_string(),
85                "error".to_string(),
86                "test".to_string(),
87                "architecture".to_string(),
88                "database".to_string(),
89                "algorithm".to_string(),
90                "refactor".to_string(),
91                "debug".to_string(),
92                "deploy".to_string(),
93                "cache".to_string(),
94                "async".to_string(),
95                "concurrent".to_string(),
96            ],
97        }
98    }
99
100    /// Analyze conversation complexity
101    pub fn analyze(messages: &[Message]) -> ComplexityLevel {
102        let analyzer = Self::default();
103        analyzer.analyze_complexity(messages)
104    }
105
106    /// Calculate complexity score
107    pub fn analyze_complexity(&self, messages: &[Message]) -> ComplexityLevel {
108        if messages.is_empty() {
109            return ComplexityLevel::Low;
110        }
111
112        let mut score = 0.0;
113
114        // 1. Code density detection
115        let code_count = messages.iter()
116            .filter(|m| self.has_code(m))
117            .count();
118        score += code_count as f32 * self.config.code_weight;
119
120        // 2. Tool usage frequency
121        let tool_count = messages.iter()
122            .filter(|m| self.has_tool_use(m))
123            .count();
124        score += tool_count as f32 * self.config.tool_weight;
125
126        // 3. Technical keyword density
127        let keyword_hits = messages.iter()
128            .map(|m| self.count_keywords(m))
129            .sum::<usize>();
130        score += keyword_hits as f32 * self.config.keyword_weight;
131
132        // 4. Error mentions
133        let error_count = messages.iter()
134            .filter(|m| self.has_error(m))
135            .count();
136        score += error_count as f32 * self.config.error_weight;
137
138        // Normalize by message count
139        score /= messages.len() as f32;
140
141        // Determine level
142        if score >= self.config.high_threshold {
143            ComplexityLevel::High
144        } else if score >= self.config.medium_threshold {
145            ComplexityLevel::Medium
146        } else {
147            ComplexityLevel::Low
148        }
149    }
150
151    /// Check if message contains code
152    fn has_code(&self, message: &Message) -> bool {
153        let content = self.get_text_content(message);
154        content.contains("```") || 
155            content.contains("fn ") || 
156            content.contains("function ") ||
157            content.contains("class ") ||
158            content.contains("struct ")
159    }
160
161    /// Check if message has tool use
162    fn has_tool_use(&self, message: &Message) -> bool {
163        matches!(message.content, MessageContent::Blocks(_)) ||
164            self.get_text_content(message).contains("tool") ||
165            self.get_text_content(message).contains("工具")
166    }
167
168    /// Check if message mentions error
169    fn has_error(&self, message: &Message) -> bool {
170        let content = self.get_text_content(message);
171        content.contains("error") ||
172            content.contains("failed") ||
173            content.contains("错误") ||
174            content.contains("失败") ||
175            content.contains("异常") ||
176            content.contains("exception")
177    }
178
179    /// Count technical keywords in message
180    fn count_keywords(&self, message: &Message) -> usize {
181        let content = self.get_text_content(message).to_lowercase();
182        self.tech_keywords.iter()
183            .filter(|kw| content.contains(&kw.to_lowercase()))
184            .count()
185    }
186
187    /// Extract text content from message
188    fn get_text_content(&self, message: &Message) -> String {
189        match &message.content {
190            MessageContent::Text(text) => text.clone(),
191            MessageContent::Blocks(blocks) => {
192                blocks.iter()
193                    .filter_map(|block| {
194                        if let crate::providers::ContentBlock::Text { text } = block {
195                            Some(text.clone())
196                        } else {
197                            None
198                        }
199                    })
200                    .collect::<Vec<_>>()
201                    .join("\n")
202            }
203        }
204    }
205
206    /// Get complexity description
207    pub fn complexity_description(level: ComplexityLevel) -> &'static str {
208        match level {
209            ComplexityLevel::High => "技术讨论密集:大量代码、工具使用、错误处理",
210            ComplexityLevel::Medium => "混合对话:部分技术内容",
211            ComplexityLevel::Low => "简单对话:少量技术内容",
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Role;

    /// Builds a plain-text message sent by `role`.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(text.to_string()),
        }
    }

    /// Builds an analyzer with the default weights but custom thresholds, so a
    /// short fixture conversation can reach the level under test.
    fn analyzer_with_thresholds(high_threshold: f32, medium_threshold: f32) -> ComplexityAnalyzer {
        ComplexityAnalyzer::new(ComplexityConfig {
            high_threshold,
            medium_threshold,
            ..Default::default()
        })
    }

    /// An empty conversation is always rated `Low`.
    #[test]
    fn test_empty_messages() {
        assert_eq!(ComplexityAnalyzer::analyze(&[]), ComplexityLevel::Low);
    }

    /// Keywords, a fenced code block and an error mention push the average past
    /// the (lowered) high threshold.
    #[test]
    fn test_high_complexity() {
        let analyzer = analyzer_with_thresholds(0.5, 0.3);

        let messages = vec![
            text_message(Role::User, "这个函数性能有问题,需要优化算法"),
            text_message(
                Role::Assistant,
                "好的,我来优化这个函数:\n```rust\nfn optimize() {}\n```",
            ),
            text_message(Role::User, "测试失败了,出现错误"),
        ];

        assert_eq!(analyzer.analyze_complexity(&messages), ComplexityLevel::High);
    }

    /// A single keyword hit clears a very low medium threshold but not the high one.
    #[test]
    fn test_medium_complexity() {
        let analyzer = analyzer_with_thresholds(0.5, 0.05);

        let messages = vec![
            text_message(Role::User, "如何在数据库中查询数据?"),
            text_message(Role::Assistant, "你可以使用 SQL 查询"),
        ];

        assert_eq!(analyzer.analyze_complexity(&messages), ComplexityLevel::Medium);
    }

    /// Small talk without any technical signal is rated `Low` with the defaults.
    #[test]
    fn test_low_complexity() {
        let messages = vec![
            text_message(Role::User, "你好"),
            text_message(Role::Assistant, "你好!有什么可以帮助你的?"),
        ];

        assert_eq!(ComplexityAnalyzer::analyze(&messages), ComplexityLevel::Low);
    }
}