Skip to main content

aster/agents/subagent_scheduler/
summary.rs

//! Summary generation module.
//!
//! Produces condensed summaries of SubAgent results so that fewer tokens are returned to the parent agent.
4
5use super::types::{SubAgentResult, TokenUsage};
6
/// Summary generator.
///
/// Turns sub-agent results into short text summaries that stay within an
/// approximate token budget.
#[derive(Debug, Clone)]
pub struct SummaryGenerator {
    /// Maximum number of (estimated) tokens a summary may occupy.
    max_tokens: usize,
}
12
13impl Default for SummaryGenerator {
14    fn default() -> Self {
15        Self::new(2000)
16    }
17}
18
19impl SummaryGenerator {
20    /// 创建摘要生成器
21    pub fn new(max_tokens: usize) -> Self {
22        Self { max_tokens }
23    }
24
25    /// 为单个结果生成摘要
26    pub fn summarize_result(&self, result: &SubAgentResult) -> String {
27        if let Some(summary) = &result.summary {
28            return self.truncate_to_tokens(summary, self.max_tokens);
29        }
30
31        if let Some(output) = &result.output {
32            return self.create_summary_from_output(output, result);
33        }
34
35        if let Some(error) = &result.error {
36            return format!("任务 {} 失败: {}", result.task_id, error);
37        }
38
39        format!("任务 {} 完成,无输出", result.task_id)
40    }
41
42    /// 合并多个结果的摘要
43    pub fn merge_summaries(&self, results: &[SubAgentResult]) -> String {
44        let mut sections = Vec::new();
45        let mut total_tokens = 0;
46        let tokens_per_result = self.max_tokens / results.len().max(1);
47
48        for result in results {
49            let summary = self.summarize_result(result);
50            let truncated = self.truncate_to_tokens(&summary, tokens_per_result);
51
52            let section = if result.success {
53                format!("✅ {}: {}", result.task_id, truncated)
54            } else {
55                format!("❌ {}: {}", result.task_id, truncated)
56            };
57
58            total_tokens += self.estimate_tokens(&section);
59            if total_tokens > self.max_tokens {
60                sections.push("... (更多结果已省略)".to_string());
61                break;
62            }
63
64            sections.push(section);
65        }
66
67        // 添加统计信息
68        let success_count = results.iter().filter(|r| r.success).count();
69        let fail_count = results.len() - success_count;
70        let total_duration: u64 = results.iter().map(|r| r.duration.as_millis() as u64).sum();
71
72        let stats = format!(
73            "\n---\n📊 统计: {} 成功, {} 失败, 总耗时 {:.2}s",
74            success_count,
75            fail_count,
76            total_duration as f64 / 1000.0
77        );
78
79        format!("{}\n{}", sections.join("\n\n"), stats)
80    }
81
82    /// 从输出创建摘要
83    fn create_summary_from_output(&self, output: &str, result: &SubAgentResult) -> String {
84        let status = if result.success { "成功" } else { "失败" };
85        let duration = result.duration.as_secs_f64();
86
87        // 提取关键信息
88        let key_points = self.extract_key_points(output);
89
90        let mut summary = format!(
91            "任务 {} {} (耗时 {:.2}s)\n",
92            result.task_id, status, duration
93        );
94
95        if !key_points.is_empty() {
96            summary.push_str("关键发现:\n");
97            for point in key_points.iter().take(5) {
98                summary.push_str(&format!("- {}\n", point));
99            }
100        }
101
102        self.truncate_to_tokens(&summary, self.max_tokens)
103    }
104
105    /// 提取关键点
106    fn extract_key_points(&self, text: &str) -> Vec<String> {
107        let mut points = Vec::new();
108
109        // 提取以特定标记开头的行
110        for line in text.lines() {
111            let trimmed = line.trim();
112            if trimmed.starts_with("- ")
113                || trimmed.starts_with("* ")
114                || trimmed.starts_with("• ")
115                || trimmed.starts_with("✓ ")
116                || trimmed.starts_with("✅ ")
117            {
118                points.push(trimmed.chars().skip(2).collect());
119            } else if trimmed.starts_with("1.")
120                || trimmed.starts_with("2.")
121                || trimmed.starts_with("3.")
122            {
123                if let Some(content) = trimmed.split_once('.') {
124                    points.push(content.1.trim().to_string());
125                }
126            }
127        }
128
129        // 如果没有找到列表项,提取首尾段落
130        if points.is_empty() {
131            let paragraphs: Vec<&str> = text
132                .split("\n\n")
133                .filter(|p| !p.trim().is_empty())
134                .collect();
135
136            if let Some(first) = paragraphs.first() {
137                points.push(self.truncate_text(first, 200));
138            }
139            if paragraphs.len() > 1 {
140                if let Some(last) = paragraphs.last() {
141                    points.push(self.truncate_text(last, 200));
142                }
143            }
144        }
145
146        points
147    }
148
149    /// 截断文本到指定字符数
150    fn truncate_text(&self, text: &str, max_chars: usize) -> String {
151        if text.chars().count() <= max_chars {
152            text.to_string()
153        } else {
154            let truncated: String = text.chars().take(max_chars - 3).collect();
155            format!("{}...", truncated)
156        }
157    }
158
159    /// 截断到指定 token 数
160    fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
161        let estimated = self.estimate_tokens(text);
162        if estimated <= max_tokens {
163            return text.to_string();
164        }
165
166        // 粗略估算:4 字符 ≈ 1 token
167        let max_chars = max_tokens * 4;
168        self.truncate_text(text, max_chars)
169    }
170
171    /// 估算 token 数(粗略)
172    fn estimate_tokens(&self, text: &str) -> usize {
173        // 简单估算:4 字符 ≈ 1 token
174        text.len() / 4
175    }
176}
177
178/// 计算总 token 使用量
179pub fn calculate_total_token_usage(results: &[SubAgentResult]) -> TokenUsage {
180    let mut total = TokenUsage::default();
181
182    for result in results {
183        if let Some(usage) = &result.token_usage {
184            total.input_tokens += usage.input_tokens;
185            total.output_tokens += usage.output_tokens;
186            total.total_tokens += usage.total_tokens;
187        }
188    }
189
190    total
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Builds a one-second result fixture.
    ///
    /// Failed results carry a fixed error message; every fixture reports
    /// 100 input / 50 output / 150 total tokens.
    fn create_test_result(task_id: &str, success: bool, output: Option<&str>) -> SubAgentResult {
        let error = if success {
            None
        } else {
            Some(String::from("测试错误"))
        };
        let usage = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            total_tokens: 150,
        };

        SubAgentResult {
            task_id: task_id.to_owned(),
            success,
            output: output.map(str::to_owned),
            summary: None,
            error,
            duration: Duration::from_secs(1),
            retries: 0,
            started_at: Utc::now(),
            completed_at: Utc::now(),
            token_usage: Some(usage),
            metadata: HashMap::new(),
        }
    }

    /// A successful result with output mentions the task id and a success status.
    #[test]
    fn test_summarize_success_result() {
        let fixture = create_test_result("task-1", true, Some("任务完成"));
        let text = SummaryGenerator::new(1000).summarize_result(&fixture);

        for expected in &["task-1", "成功"] {
            assert!(text.contains(*expected));
        }
    }

    /// A failed result without output mentions the task id and a failure status.
    #[test]
    fn test_summarize_failed_result() {
        let fixture = create_test_result("task-1", false, None);
        let text = SummaryGenerator::new(1000).summarize_result(&fixture);

        for expected in &["task-1", "失败"] {
            assert!(text.contains(*expected));
        }
    }

    /// Merging lists every task and reports success/failure counts.
    #[test]
    fn test_merge_summaries() {
        let fixtures = vec![
            create_test_result("task-1", true, Some("结果1")),
            create_test_result("task-2", true, Some("结果2")),
            create_test_result("task-3", false, None),
        ];
        let report = SummaryGenerator::new(2000).merge_summaries(&fixtures);

        for expected in &["task-1", "task-2", "task-3", "2 成功", "1 失败"] {
            assert!(report.contains(*expected));
        }
    }

    /// Bulleted lines are extracted with their markers removed.
    #[test]
    fn test_extract_key_points() {
        let input = "概述\n- 发现1\n- 发现2\n* 发现3";
        let extracted = SummaryGenerator::new(1000).extract_key_points(input);

        assert_eq!(extracted.len(), 3);
        assert!(extracted.iter().any(|point| point == "发现1"));
    }

    /// Token usage is summed field by field across results.
    #[test]
    fn test_calculate_total_token_usage() {
        let fixtures: Vec<SubAgentResult> = ["task-1", "task-2"]
            .iter()
            .map(|id| create_test_result(id, true, None))
            .collect();
        let sum = calculate_total_token_usage(&fixtures);

        assert_eq!(sum.input_tokens, 200);
        assert_eq!(sum.output_tokens, 100);
        assert_eq!(sum.total_tokens, 300);
    }
}