Skip to main content

aster/agents/subagent_scheduler/
strategy.rs

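//! 调度策略模块
//!
//! 根据任务特征自动选择最优调度策略
//!
//! Scheduling strategy module: automatically selects the best scheduling
//! strategy for a batch of sub-agent tasks based on their characteristics.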
4
5use serde::{Deserialize, Serialize};
6
7use super::types::SubAgentTask;
8
9/// 调度策略
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
11#[serde(rename_all = "camelCase")]
12pub enum SchedulingStrategy {
13    /// 单 Agent 直接执行(简单任务)
14    SingleAgent,
15    /// 串行执行(有依赖的任务)
16    Sequential,
17    /// 并行执行(独立任务)
18    Parallel,
19    /// 广度优先并行(研究任务)
20    BreadthFirst,
21    /// 自适应(根据任务特征自动选择)
22    #[default]
23    Adaptive,
24}
25
26/// 任务复杂度
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "camelCase")]
29pub enum TaskComplexity {
30    /// 简单任务(单次 LLM 调用)
31    Simple,
32    /// 中等任务(需要多步骤)
33    Medium,
34    /// 复杂任务(需要多个子任务)
35    Complex,
36    /// 研究任务(需要广泛探索)
37    Research,
38}
39
/// Stateless selector that picks scheduling strategies and estimates task complexity.
///
/// The type carries no data; all of its functionality is exposed as associated
/// functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct StrategySelector;
42
43impl StrategySelector {
44    /// 根据任务列表自动选择策略
45    pub fn select(tasks: &[SubAgentTask]) -> SchedulingStrategy {
46        if tasks.is_empty() {
47            return SchedulingStrategy::SingleAgent;
48        }
49
50        if tasks.len() == 1 {
51            return SchedulingStrategy::SingleAgent;
52        }
53
54        // 检查是否有依赖关系
55        let has_dependencies = tasks.iter().any(|t| t.has_dependencies());
56
57        // 检查任务类型分布
58        let task_types: Vec<&str> = tasks.iter().map(|t| t.task_type.as_str()).collect();
59        let is_research = task_types
60            .iter()
61            .any(|t| matches!(*t, "research" | "explore" | "search" | "analyze"));
62
63        // 检查是否高度可并行化
64        let parallelizable_ratio = if has_dependencies {
65            let independent_count = tasks.iter().filter(|t| !t.has_dependencies()).count();
66            independent_count as f64 / tasks.len() as f64
67        } else {
68            1.0
69        };
70
71        // 策略选择逻辑
72        if is_research && parallelizable_ratio > 0.7 {
73            SchedulingStrategy::BreadthFirst
74        } else if has_dependencies && parallelizable_ratio < 0.3 {
75            SchedulingStrategy::Sequential
76        } else if parallelizable_ratio > 0.5 {
77            SchedulingStrategy::Parallel
78        } else {
79            SchedulingStrategy::Sequential
80        }
81    }
82
83    /// 估算任务复杂度
84    pub fn estimate_complexity(task: &SubAgentTask) -> TaskComplexity {
85        // 基于任务类型估算
86        let type_complexity = match task.task_type.as_str() {
87            "explore" | "search" => TaskComplexity::Simple,
88            "analyze" | "review" => TaskComplexity::Medium,
89            "code" | "implement" => TaskComplexity::Complex,
90            "research" | "investigate" => TaskComplexity::Research,
91            _ => TaskComplexity::Medium,
92        };
93
94        // 基于 prompt 长度调整
95        let prompt_len = task.prompt.len();
96        if prompt_len > 1000 {
97            match type_complexity {
98                TaskComplexity::Simple => TaskComplexity::Medium,
99                TaskComplexity::Medium => TaskComplexity::Complex,
100                _ => type_complexity,
101            }
102        } else {
103            type_complexity
104        }
105    }
106
107    /// 根据复杂度推荐并发数
108    pub fn recommended_concurrency(complexity: TaskComplexity) -> usize {
109        match complexity {
110            TaskComplexity::Simple => 10,
111            TaskComplexity::Medium => 5,
112            TaskComplexity::Complex => 3,
113            TaskComplexity::Research => 8,
114        }
115    }
116
117    /// 根据复杂度推荐模型
118    pub fn recommended_model(complexity: TaskComplexity) -> &'static str {
119        match complexity {
120            TaskComplexity::Simple => "haiku",
121            TaskComplexity::Medium => "sonnet",
122            TaskComplexity::Complex => "opus",
123            TaskComplexity::Research => "sonnet",
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_select_empty_tasks() {
        assert_eq!(
            StrategySelector::select(&[]),
            SchedulingStrategy::SingleAgent
        );
    }

    #[test]
    fn test_select_single_task() {
        let tasks = vec![SubAgentTask::new("t1", "explore", "分析")];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::SingleAgent
        );
    }

    #[test]
    fn test_select_parallel_tasks() {
        // Use the `code` type so the batch is not classified as research.
        let tasks = vec![
            SubAgentTask::new("t1", "code", "实现1"),
            SubAgentTask::new("t2", "code", "实现2"),
            SubAgentTask::new("t3", "code", "实现3"),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Parallel
        );
    }

    #[test]
    fn test_select_mostly_independent_tasks_run_in_parallel() {
        // 3 of 4 tasks are dependency-free: ratio 0.75 > 0.5.
        let tasks = vec![
            SubAgentTask::new("t1", "code", "实现1"),
            SubAgentTask::new("t2", "code", "实现2"),
            SubAgentTask::new("t3", "code", "实现3"),
            SubAgentTask::new("t4", "code", "实现4").with_dependencies(vec!["t1"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Parallel
        );
    }

    #[test]
    fn test_select_sequential_tasks() {
        let tasks = vec![
            SubAgentTask::new("t1", "code", "实现1"),
            SubAgentTask::new("t2", "code", "实现2").with_dependencies(vec!["t1"]),
            SubAgentTask::new("t3", "code", "实现3").with_dependencies(vec!["t2"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Sequential
        );
    }

    #[test]
    fn test_select_research_tasks() {
        let tasks = vec![
            SubAgentTask::new("t1", "research", "研究1"),
            SubAgentTask::new("t2", "research", "研究2"),
            SubAgentTask::new("t3", "research", "研究3"),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::BreadthFirst
        );
    }

    #[test]
    fn test_select_research_below_breadth_first_ratio_runs_parallel() {
        // Research batch, but only 2 of 3 tasks are dependency-free
        // (ratio ~0.67, not above 0.7), so it falls back to Parallel.
        let tasks = vec![
            SubAgentTask::new("t1", "research", "研究1"),
            SubAgentTask::new("t2", "research", "研究2"),
            SubAgentTask::new("t3", "research", "研究3").with_dependencies(vec!["t1"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Parallel
        );
    }

    #[test]
    fn test_estimate_complexity() {
        let simple = SubAgentTask::new("t1", "explore", "简单任务");
        let complex = SubAgentTask::new("t2", "code", "复杂任务");

        assert_eq!(
            StrategySelector::estimate_complexity(&simple),
            TaskComplexity::Simple
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&complex),
            TaskComplexity::Complex
        );
    }

    #[test]
    fn test_estimate_complexity_long_prompt_threshold() {
        // Exactly at the threshold: no bump.
        let at_limit = SubAgentTask::new("t1", "explore", &"x".repeat(1000));
        assert_eq!(
            StrategySelector::estimate_complexity(&at_limit),
            TaskComplexity::Simple
        );

        // One past the threshold: Simple -> Medium, Medium -> Complex,
        // Research unchanged.
        let long = "x".repeat(1001);
        let explore = SubAgentTask::new("t2", "explore", &long);
        let analyze = SubAgentTask::new("t3", "analyze", &long);
        let research = SubAgentTask::new("t4", "research", &long);
        assert_eq!(
            StrategySelector::estimate_complexity(&explore),
            TaskComplexity::Medium
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&analyze),
            TaskComplexity::Complex
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&research),
            TaskComplexity::Research
        );
    }

    #[test]
    fn test_recommended_concurrency() {
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Simple),
            10
        );
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Medium),
            5
        );
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Complex),
            3
        );
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Research),
            8
        );
    }

    #[test]
    fn test_recommended_model() {
        assert_eq!(
            StrategySelector::recommended_model(TaskComplexity::Simple),
            "haiku"
        );
        assert_eq!(
            StrategySelector::recommended_model(TaskComplexity::Complex),
            "opus"
        );
    }
}