aster/agents/subagent_scheduler/
strategy.rs1use serde::{Deserialize, Serialize};
6
7use super::types::SubAgentTask;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
11#[serde(rename_all = "camelCase")]
12pub enum SchedulingStrategy {
13 SingleAgent,
15 Sequential,
17 Parallel,
19 BreadthFirst,
21 #[default]
23 Adaptive,
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "camelCase")]
29pub enum TaskComplexity {
30 Simple,
32 Medium,
34 Complex,
36 Research,
38}
39
/// Stateless namespace for the heuristics that pick a [`SchedulingStrategy`],
/// estimate a task's [`TaskComplexity`], and map complexity to concurrency and
/// model recommendations.
#[derive(Debug, Clone, Copy, Default)]
pub struct StrategySelector;
42
43impl StrategySelector {
44 pub fn select(tasks: &[SubAgentTask]) -> SchedulingStrategy {
46 if tasks.is_empty() {
47 return SchedulingStrategy::SingleAgent;
48 }
49
50 if tasks.len() == 1 {
51 return SchedulingStrategy::SingleAgent;
52 }
53
54 let has_dependencies = tasks.iter().any(|t| t.has_dependencies());
56
57 let task_types: Vec<&str> = tasks.iter().map(|t| t.task_type.as_str()).collect();
59 let is_research = task_types
60 .iter()
61 .any(|t| matches!(*t, "research" | "explore" | "search" | "analyze"));
62
63 let parallelizable_ratio = if has_dependencies {
65 let independent_count = tasks.iter().filter(|t| !t.has_dependencies()).count();
66 independent_count as f64 / tasks.len() as f64
67 } else {
68 1.0
69 };
70
71 if is_research && parallelizable_ratio > 0.7 {
73 SchedulingStrategy::BreadthFirst
74 } else if has_dependencies && parallelizable_ratio < 0.3 {
75 SchedulingStrategy::Sequential
76 } else if parallelizable_ratio > 0.5 {
77 SchedulingStrategy::Parallel
78 } else {
79 SchedulingStrategy::Sequential
80 }
81 }
82
83 pub fn estimate_complexity(task: &SubAgentTask) -> TaskComplexity {
85 let type_complexity = match task.task_type.as_str() {
87 "explore" | "search" => TaskComplexity::Simple,
88 "analyze" | "review" => TaskComplexity::Medium,
89 "code" | "implement" => TaskComplexity::Complex,
90 "research" | "investigate" => TaskComplexity::Research,
91 _ => TaskComplexity::Medium,
92 };
93
94 let prompt_len = task.prompt.len();
96 if prompt_len > 1000 {
97 match type_complexity {
98 TaskComplexity::Simple => TaskComplexity::Medium,
99 TaskComplexity::Medium => TaskComplexity::Complex,
100 _ => type_complexity,
101 }
102 } else {
103 type_complexity
104 }
105 }
106
107 pub fn recommended_concurrency(complexity: TaskComplexity) -> usize {
109 match complexity {
110 TaskComplexity::Simple => 10,
111 TaskComplexity::Medium => 5,
112 TaskComplexity::Complex => 3,
113 TaskComplexity::Research => 8,
114 }
115 }
116
117 pub fn recommended_model(complexity: TaskComplexity) -> &'static str {
119 match complexity {
120 TaskComplexity::Simple => "haiku",
121 TaskComplexity::Medium => "sonnet",
122 TaskComplexity::Complex => "opus",
123 TaskComplexity::Research => "sonnet",
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_select_single_task() {
        let tasks = vec![SubAgentTask::new("t1", "explore", "分析")];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::SingleAgent
        );
    }

    #[test]
    fn test_select_empty_tasks() {
        assert_eq!(
            StrategySelector::select(&[]),
            SchedulingStrategy::SingleAgent
        );
    }

    #[test]
    fn test_select_parallel_tasks() {
        let tasks = vec![
            SubAgentTask::new("t1", "code", "实现1"),
            SubAgentTask::new("t2", "code", "实现2"),
            SubAgentTask::new("t3", "code", "实现3"),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Parallel
        );
    }

    #[test]
    fn test_select_parallel_with_some_dependencies() {
        // 2 of 3 tasks are independent (ratio ~0.67 > 0.5).
        let tasks = vec![
            SubAgentTask::new("t1", "code", "impl 1"),
            SubAgentTask::new("t2", "code", "impl 2"),
            SubAgentTask::new("t3", "code", "impl 3").with_dependencies(vec!["t1"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Parallel
        );
    }

    #[test]
    fn test_select_sequential_tasks() {
        let tasks = vec![
            SubAgentTask::new("t1", "code", "实现1"),
            SubAgentTask::new("t2", "code", "实现2").with_dependencies(vec!["t1"]),
            SubAgentTask::new("t3", "code", "实现3").with_dependencies(vec!["t2"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Sequential
        );
    }

    #[test]
    fn test_select_research_tasks() {
        let tasks = vec![
            SubAgentTask::new("t1", "research", "研究1"),
            SubAgentTask::new("t2", "research", "研究2"),
            SubAgentTask::new("t3", "research", "研究3"),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::BreadthFirst
        );
    }

    #[test]
    fn test_select_research_with_few_dependencies() {
        // 3 of 4 tasks are independent (ratio 0.75 > 0.7).
        let tasks = vec![
            SubAgentTask::new("t1", "research", "r1"),
            SubAgentTask::new("t2", "research", "r2"),
            SubAgentTask::new("t3", "research", "r3"),
            SubAgentTask::new("t4", "research", "r4").with_dependencies(vec!["t1"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::BreadthFirst
        );
    }

    #[test]
    fn test_select_research_mostly_dependent_is_sequential() {
        // Only 1 of 3 tasks is independent, so neither breadth-first nor parallel.
        let tasks = vec![
            SubAgentTask::new("t1", "research", "r1"),
            SubAgentTask::new("t2", "research", "r2").with_dependencies(vec!["t1"]),
            SubAgentTask::new("t3", "research", "r3").with_dependencies(vec!["t2"]),
        ];
        assert_eq!(
            StrategySelector::select(&tasks),
            SchedulingStrategy::Sequential
        );
    }

    #[test]
    fn test_estimate_complexity() {
        let simple = SubAgentTask::new("t1", "explore", "简单任务");
        let complex = SubAgentTask::new("t2", "code", "复杂任务");

        assert_eq!(
            StrategySelector::estimate_complexity(&simple),
            TaskComplexity::Simple
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&complex),
            TaskComplexity::Complex
        );
    }

    #[test]
    fn test_estimate_complexity_unknown_type_is_medium() {
        let task = SubAgentTask::new("t1", "translate", "short");
        assert_eq!(
            StrategySelector::estimate_complexity(&task),
            TaskComplexity::Medium
        );
    }

    #[test]
    fn test_estimate_complexity_long_prompt_escalates() {
        let long_prompt = "a".repeat(1001);
        let explore = SubAgentTask::new("t1", "explore", long_prompt.as_str());
        let analyze = SubAgentTask::new("t2", "analyze", long_prompt.as_str());
        let code = SubAgentTask::new("t3", "code", long_prompt.as_str());
        let research = SubAgentTask::new("t4", "research", long_prompt.as_str());

        assert_eq!(
            StrategySelector::estimate_complexity(&explore),
            TaskComplexity::Medium
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&analyze),
            TaskComplexity::Complex
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&code),
            TaskComplexity::Complex
        );
        assert_eq!(
            StrategySelector::estimate_complexity(&research),
            TaskComplexity::Research
        );
    }

    #[test]
    fn test_estimate_complexity_threshold_boundary() {
        // Exactly at the threshold: no escalation.
        let at_limit = "a".repeat(1000);
        let task = SubAgentTask::new("t1", "explore", at_limit.as_str());
        assert_eq!(
            StrategySelector::estimate_complexity(&task),
            TaskComplexity::Simple
        );
    }

    #[test]
    fn test_recommended_concurrency() {
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Simple),
            10
        );
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Medium),
            5
        );
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Complex),
            3
        );
        assert_eq!(
            StrategySelector::recommended_concurrency(TaskComplexity::Research),
            8
        );
    }

    #[test]
    fn test_recommended_model() {
        assert_eq!(
            StrategySelector::recommended_model(TaskComplexity::Simple),
            "haiku"
        );
        assert_eq!(
            StrategySelector::recommended_model(TaskComplexity::Complex),
            "opus"
        );
        assert_eq!(
            StrategySelector::recommended_model(TaskComplexity::Medium),
            "sonnet"
        );
        assert_eq!(
            StrategySelector::recommended_model(TaskComplexity::Research),
            "sonnet"
        );
    }

    #[test]
    fn test_default_strategy_is_adaptive() {
        assert_eq!(
            SchedulingStrategy::default(),
            SchedulingStrategy::Adaptive
        );
    }
}