1use crate::agent::Agent;
32use crate::config::Config;
33use crate::error::{HeliosError, Result};
34use crate::tools::Tool;
35use serde::{Deserialize, Serialize};
36use std::collections::HashMap;
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct AgentConfig {
41 pub name: String,
43 pub system_prompt: String,
45 #[serde(default)]
47 pub tool_indices: Vec<usize>,
48 pub role: String,
50}
51
52#[derive(Debug, Deserialize)]
54struct OrchestrationPlanJson {
55 num_agents: usize,
56 reasoning: String,
57 agents: Vec<AgentConfig>,
58 task_breakdown: HashMap<String, String>,
59}
60
61pub struct SpawnedAgent {
63 pub agent: Agent,
65 pub config: AgentConfig,
67 pub result: Option<String>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct OrchestrationPlan {
74 pub task: String,
76 pub num_agents: usize,
78 pub reasoning: String,
80 pub agents: Vec<AgentConfig>,
82 pub task_breakdown: HashMap<String, String>,
84}
85
86pub struct AutoForest {
88 config: Config,
89 tools: Vec<Box<dyn Tool>>,
90 spawned_agents: Vec<SpawnedAgent>,
91 orchestration_plan: Option<OrchestrationPlan>,
92 orchestrator_agent: Option<Agent>,
93}
94
95impl AutoForest {
96 #[allow(clippy::new_ret_no_self)]
98 pub fn new(config: Config) -> AutoForestBuilder {
99 AutoForestBuilder::new(config)
100 }
101
102 pub fn orchestration_plan(&self) -> Option<&OrchestrationPlan> {
104 self.orchestration_plan.as_ref()
105 }
106
107 pub fn spawned_agents(&self) -> &[SpawnedAgent] {
109 &self.spawned_agents
110 }
111
112 async fn generate_orchestration_plan(&mut self, task: &str) -> Result<OrchestrationPlan> {
114 let tools_info = self
116 .tools
117 .iter()
118 .enumerate()
119 .map(|(i, tool)| format!("- Tool {}: {} ({})", i, tool.name(), tool.description()))
120 .collect::<Vec<_>>()
121 .join("\n");
122
123 let orchestrator_prompt = format!(
124 r#"You are an expert task orchestrator. Your job is to analyze a task and create an optimal plan for a forest of AI agents to complete it.
125
126Available tools:
127{}
128
129Given the task, you must:
1301. Determine the optimal number of agents (1-5)
1312. Define each agent's role and specialization
1323. Create specialized system prompts for each agent
1334. Assign tools to each agent based on their role
1345. Break down the task into subtasks for each agent
135
136Respond with ONLY a JSON object with this structure (no markdown, no extra text):
137{{
138 "num_agents": <number>,
139 "reasoning": "<brief explanation>",
140 "agents": [
141 {{
142 "name": "<agent_name>",
143 "role": "<role>",
144 "system_prompt": "<specialized_prompt>",
145 "tool_indices": [<indices>]
146 }}
147 ],
148 "task_breakdown": {{
149 "<agent_name>": "<specific_task_for_this_agent>"
150 }}
151}}"#,
152 tools_info
153 );
154
155 if self.orchestrator_agent.is_none() {
157 let orchestrator = Agent::builder("Orchestrator")
158 .config(self.config.clone())
159 .system_prompt(&orchestrator_prompt)
160 .build()
161 .await?;
162 self.orchestrator_agent = Some(orchestrator);
163 }
164
165 let orchestrator = self.orchestrator_agent.as_mut().ok_or_else(|| {
167 HeliosError::AgentError("Failed to create orchestrator agent".to_string())
168 })?;
169
170 let response = orchestrator.chat(&format!("Task: {}", task)).await?;
172
173 let plan_data: OrchestrationPlanJson = serde_json::from_str(&response).map_err(|e| {
175 HeliosError::AgentError(format!("Failed to parse orchestration plan: {}", e))
176 })?;
177
178 let plan = OrchestrationPlan {
180 task: task.to_string(),
181 num_agents: plan_data.num_agents,
182 reasoning: plan_data.reasoning,
183 agents: plan_data.agents,
184 task_breakdown: plan_data.task_breakdown,
185 };
186
187 self.orchestration_plan = Some(plan.clone());
188 Ok(plan)
189 }
190
191 async fn spawn_agents_from_plan(&mut self, plan: &OrchestrationPlan) -> Result<()> {
193 self.spawned_agents.clear();
194
195 for agent_config in &plan.agents {
196 let agent = Agent::builder(&agent_config.name)
198 .config(self.config.clone())
199 .system_prompt(&agent_config.system_prompt)
200 .build()
201 .await?;
202
203 let spawned = SpawnedAgent {
207 agent,
208 config: agent_config.clone(),
209 result: None,
210 };
211
212 self.spawned_agents.push(spawned);
213 }
214
215 Ok(())
216 }
217
218 pub async fn execute_task(&mut self, task: &str) -> Result<String> {
220 let plan = self.generate_orchestration_plan(task).await?;
222
223 self.spawn_agents_from_plan(&plan).await?;
225
226 let mut futures = Vec::new();
228
229 for spawned_agent in self.spawned_agents.drain(..) {
230 let agent_task = plan
231 .task_breakdown
232 .get(&spawned_agent.config.name)
233 .cloned()
234 .unwrap_or_else(|| format!("Complete your assigned portion of: {}", task));
235
236 let future = async move {
237 let mut agent = spawned_agent.agent;
238 let config = spawned_agent.config;
239 let result = agent.chat(&agent_task).await;
240 (agent, config, result)
241 };
242
243 futures.push(future);
244 }
245
246 let completed_agents = futures::future::join_all(futures).await;
248
249 let mut results = HashMap::new();
251 self.spawned_agents.clear(); for (agent, config, result) in completed_agents {
254 let agent_name = config.name.clone();
255 let (result_string, result_for_map) = match result {
256 Ok(output) => (Some(output.clone()), output),
257 Err(e) => {
258 let err_msg = format!("Error: {}", e);
259 (Some(err_msg.clone()), err_msg)
260 }
261 };
262 results.insert(agent_name, result_for_map);
263
264 self.spawned_agents.push(SpawnedAgent {
265 agent,
266 config,
267 result: result_string,
268 });
269 }
270
271 let aggregated_result = self.aggregate_results(&results, task).await?;
273
274 Ok(aggregated_result)
275 }
276
277 pub async fn do_task(&mut self, task: &str) -> Result<String> {
279 self.execute_task(task).await
280 }
281
282 pub async fn run(&mut self, task: &str) -> Result<String> {
284 self.execute_task(task).await
285 }
286
287 async fn aggregate_results(
289 &mut self,
290 results: &HashMap<String, String>,
291 task: &str,
292 ) -> Result<String> {
293 let mut result_text = String::new();
294 result_text.push_str("## Task Execution Summary\n\n");
295 result_text.push_str(&format!("**Task**: {}\n\n", task));
296 result_text.push_str("### Agent Results:\n\n");
297
298 for (agent_name, result) in results {
299 result_text.push_str(&format!("**{}**:\n{}\n\n", agent_name, result));
300 }
301
302 if results.len() > 1 {
304 result_text.push_str("### Synthesized Analysis:\n\n");
305 let orchestrator = self
306 .orchestrator_agent
307 .as_mut()
308 .ok_or_else(|| HeliosError::AgentError("Orchestrator not available".to_string()))?;
309
310 let synthesis_prompt = format!(
311 "Synthesize these agent results into a cohesive answer:\n{}",
312 result_text
313 );
314 let synthesis = orchestrator.chat(&synthesis_prompt).await?;
315 result_text.push_str(&synthesis);
316 }
317
318 Ok(result_text)
319 }
320}
321
322pub struct AutoForestBuilder {
324 config: Config,
325 tools: Vec<Box<dyn Tool>>,
326}
327
328impl AutoForestBuilder {
329 pub fn new(config: Config) -> Self {
331 Self {
332 config,
333 tools: Vec::new(),
334 }
335 }
336
337 pub fn with_tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
339 self.tools = tools;
340 self
341 }
342
343 pub async fn build(self) -> Result<AutoForest> {
345 Ok(AutoForest {
346 config: self.config,
347 tools: self.tools,
348 spawned_agents: Vec::new(),
349 orchestration_plan: None,
350 orchestrator_agent: None,
351 })
352 }
353}
354
// Unit tests for the plain data types; nothing here needs a model or network access.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_config_creation() {
        let agent_config = AgentConfig {
            name: String::from("TestAgent"),
            system_prompt: String::from("You are helpful"),
            tool_indices: vec![0, 1],
            role: String::from("Analyzer"),
        };

        assert_eq!(agent_config.name, "TestAgent");
        assert_eq!(agent_config.tool_indices.len(), 2);
    }

    #[test]
    fn test_orchestration_plan_creation() {
        let plan = OrchestrationPlan {
            task: String::from("Test task"),
            num_agents: 2,
            reasoning: String::from("Two agents needed"),
            agents: Vec::new(),
            task_breakdown: HashMap::new(),
        };

        assert_eq!(plan.task, "Test task");
        assert_eq!(plan.num_agents, 2);
    }
}