matrixcode_core/compress/
complexity.rs1use crate::providers::{Message, MessageContent};
6
/// Coarse classification of how technical a conversation is.
///
/// Produced by `ComplexityAnalyzer::analyze_complexity`, which compares an
/// averaged per-message score against the thresholds in `ComplexityConfig`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComplexityLevel {
    /// The score reached the high threshold.
    High,
    /// The score reached the medium threshold but not the high one.
    Medium,
    /// The score stayed below the medium threshold (also used for an empty
    /// conversation).
    Low,
}
17
/// Weights and thresholds used when scoring a conversation.
///
/// `ComplexityAnalyzer::analyze_complexity` multiplies each weight by the
/// matching signal count, sums the products, divides by the number of
/// messages and compares the result against the two thresholds.
#[derive(Debug, Clone)]
pub struct ComplexityConfig {
    /// Score added for every message that looks like it contains code.
    code_weight: f32,
    /// Score added for every message that looks like it involves tool use.
    tool_weight: f32,
    /// Score added for every distinct technical keyword found in a message.
    keyword_weight: f32,
    /// Score added for every message that mentions an error or failure.
    error_weight: f32,
    /// Minimum averaged score classified as `ComplexityLevel::High`.
    high_threshold: f32,
    /// Minimum averaged score classified as `ComplexityLevel::Medium`.
    medium_threshold: f32,
}

impl Default for ComplexityConfig {
    /// Returns the default weights together with thresholds on the same scale.
    fn default() -> Self {
        Self {
            code_weight: 0.3,
            tool_weight: 0.25,
            keyword_weight: 0.15,
            error_weight: 0.2,
            // The score is a per-message average. With the weights above a
            // single message contributes at most 0.3 + 0.25 + 0.2 from the
            // code/tool/error signals plus 0.15 per distinct keyword (28
            // built-in keywords, i.e. 4.95 in total), so the previous
            // thresholds of 5.0 / 2.0 made `High` unreachable and `Medium`
            // practically unreachable. The thresholds are therefore expressed
            // on the scale of the averaged score.
            high_threshold: 0.5,
            medium_threshold: 0.05,
        }
    }
}
47
48pub struct ComplexityAnalyzer {
50 config: ComplexityConfig,
51 tech_keywords: Vec<String>,
53}
54
55impl Default for ComplexityAnalyzer {
56 fn default() -> Self {
57 Self::new(ComplexityConfig::default())
58 }
59}
60
61impl ComplexityAnalyzer {
62 pub fn new(config: ComplexityConfig) -> Self {
63 Self {
64 config,
65 tech_keywords: vec![
66 "函数".to_string(),
68 "优化".to_string(),
69 "性能".to_string(),
70 "错误".to_string(),
71 "测试".to_string(),
72 "架构".to_string(),
73 "数据库".to_string(),
74 "算法".to_string(),
75 "重构".to_string(),
76 "调试".to_string(),
77 "部署".to_string(),
78 "缓存".to_string(),
79 "并发".to_string(),
80 "异步".to_string(),
81 "function".to_string(),
83 "optimize".to_string(),
84 "performance".to_string(),
85 "error".to_string(),
86 "test".to_string(),
87 "architecture".to_string(),
88 "database".to_string(),
89 "algorithm".to_string(),
90 "refactor".to_string(),
91 "debug".to_string(),
92 "deploy".to_string(),
93 "cache".to_string(),
94 "async".to_string(),
95 "concurrent".to_string(),
96 ],
97 }
98 }
99
100 pub fn analyze(messages: &[Message]) -> ComplexityLevel {
102 let analyzer = Self::default();
103 analyzer.analyze_complexity(messages)
104 }
105
106 pub fn analyze_complexity(&self, messages: &[Message]) -> ComplexityLevel {
108 if messages.is_empty() {
109 return ComplexityLevel::Low;
110 }
111
112 let mut score = 0.0;
113
114 let code_count = messages.iter()
116 .filter(|m| self.has_code(m))
117 .count();
118 score += code_count as f32 * self.config.code_weight;
119
120 let tool_count = messages.iter()
122 .filter(|m| self.has_tool_use(m))
123 .count();
124 score += tool_count as f32 * self.config.tool_weight;
125
126 let keyword_hits = messages.iter()
128 .map(|m| self.count_keywords(m))
129 .sum::<usize>();
130 score += keyword_hits as f32 * self.config.keyword_weight;
131
132 let error_count = messages.iter()
134 .filter(|m| self.has_error(m))
135 .count();
136 score += error_count as f32 * self.config.error_weight;
137
138 score /= messages.len() as f32;
140
141 if score >= self.config.high_threshold {
143 ComplexityLevel::High
144 } else if score >= self.config.medium_threshold {
145 ComplexityLevel::Medium
146 } else {
147 ComplexityLevel::Low
148 }
149 }
150
151 fn has_code(&self, message: &Message) -> bool {
153 let content = self.get_text_content(message);
154 content.contains("```") ||
155 content.contains("fn ") ||
156 content.contains("function ") ||
157 content.contains("class ") ||
158 content.contains("struct ")
159 }
160
161 fn has_tool_use(&self, message: &Message) -> bool {
163 matches!(message.content, MessageContent::Blocks(_)) ||
164 self.get_text_content(message).contains("tool") ||
165 self.get_text_content(message).contains("工具")
166 }
167
168 fn has_error(&self, message: &Message) -> bool {
170 let content = self.get_text_content(message);
171 content.contains("error") ||
172 content.contains("failed") ||
173 content.contains("错误") ||
174 content.contains("失败") ||
175 content.contains("异常") ||
176 content.contains("exception")
177 }
178
179 fn count_keywords(&self, message: &Message) -> usize {
181 let content = self.get_text_content(message).to_lowercase();
182 self.tech_keywords.iter()
183 .filter(|kw| content.contains(&kw.to_lowercase()))
184 .count()
185 }
186
187 fn get_text_content(&self, message: &Message) -> String {
189 match &message.content {
190 MessageContent::Text(text) => text.clone(),
191 MessageContent::Blocks(blocks) => {
192 blocks.iter()
193 .filter_map(|block| {
194 if let crate::providers::ContentBlock::Text { text } = block {
195 Some(text.clone())
196 } else {
197 None
198 }
199 })
200 .collect::<Vec<_>>()
201 .join("\n")
202 }
203 }
204 }
205
206 pub fn complexity_description(level: ComplexityLevel) -> &'static str {
208 match level {
209 ComplexityLevel::High => "技术讨论密集:大量代码、工具使用、错误处理",
210 ComplexityLevel::Medium => "混合对话:部分技术内容",
211 ComplexityLevel::Low => "简单对话:少量技术内容",
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Role;

    /// Builds a plain-text message sent by `role`.
    fn text_message(role: Role, body: &str) -> Message {
        Message {
            role,
            content: MessageContent::Text(body.to_string()),
        }
    }

    /// Shorthand for a plain-text user message.
    fn user(body: &str) -> Message {
        text_message(Role::User, body)
    }

    /// Shorthand for a plain-text assistant message.
    fn assistant(body: &str) -> Message {
        text_message(Role::Assistant, body)
    }

    #[test]
    fn test_empty_messages() {
        assert_eq!(ComplexityAnalyzer::analyze(&[]), ComplexityLevel::Low);
    }

    #[test]
    fn test_high_complexity() {
        // Keywords, a fenced code block and a failure report in one conversation.
        let conversation = [
            user("这个函数性能有问题,需要优化算法"),
            assistant("好的,我来优化这个函数:\n```rust\nfn optimize() {}\n```"),
            user("测试失败了,出现错误"),
        ];

        assert_eq!(
            ComplexityAnalyzer::analyze(&conversation),
            ComplexityLevel::High
        );
    }

    #[test]
    fn test_medium_complexity() {
        // A single technical keyword and nothing else.
        let conversation = [
            user("如何在数据库中查询数据?"),
            assistant("你可以使用 SQL 查询"),
        ];

        assert_eq!(
            ComplexityAnalyzer::analyze(&conversation),
            ComplexityLevel::Medium
        );
    }

    #[test]
    fn test_low_complexity() {
        // Small talk without any technical signal.
        let conversation = [
            user("你好"),
            assistant("你好!有什么可以帮助你的?"),
        ];

        assert_eq!(
            ComplexityAnalyzer::analyze(&conversation),
            ComplexityLevel::Low
        );
    }
}