1use std::{collections::HashMap, sync::Arc};
2
3use tokio::sync::Mutex;
4
5use agentrs_core::{
6 Agent as AgentTrait, AgentError, AgentOutput, CompletionRequest, LlmProvider, Message, Result,
7};
8
9use crate::{AgentGraph, EventBus, OrchestratorEvent, SharedConversation};
10
/// Handle to a registered agent: a boxed [`AgentTrait`] object behind an async
/// [`Mutex`], wrapped in an [`Arc`] so the handle can be cloned cheaply.
type SharedAgent = Arc<Mutex<Box<dyn AgentTrait>>>;
12
/// Strategy the orchestrator uses to decide which agents run for a given input.
#[derive(Clone)]
pub enum RoutingStrategy {
    /// Runs the named agents one after another; each agent receives the output
    /// text of the previous one, and the last agent's output is returned.
    Sequential(Vec<String>),
    /// Runs the named agents concurrently on the same input and merges their
    /// outputs into a single result.
    Parallel(Vec<String>),
    /// Asks an LLM to pick one agent, which then handles the original input.
    Supervisor {
        /// Provider queried for the routing decision.
        llm: Arc<dyn LlmProvider>,
        /// Names of the agents listed in the supervisor prompt.
        agents: Vec<String>,
        /// Maximum number of routing requests sent to `llm` before giving up.
        max_turns: usize,
    },
    /// Walks an [`AgentGraph`] from its entry node until it yields no next node.
    Graph(AgentGraph),
}
32
33impl Default for RoutingStrategy {
34 fn default() -> Self {
35 Self::Sequential(Vec::new())
36 }
37}
38
/// Builder for [`MultiAgentOrchestrator`].
pub struct MultiAgentOrchestratorBuilder {
    /// Registered agents, keyed by name.
    agents: HashMap<String, SharedAgent>,
    /// Agent names in registration order; used as the sequential route when no
    /// explicit routing strategy is supplied.
    order: Vec<String>,
    /// Explicit routing strategy, if one was set.
    routing: Option<RoutingStrategy>,
    /// Optional shared conversation handed to the built orchestrator.
    shared_memory: Option<SharedConversation>,
    /// Optional event bus handed to the built orchestrator.
    event_bus: Option<Arc<dyn EventBus>>,
}
47
/// Coordinates a set of named agents according to a [`RoutingStrategy`].
pub struct MultiAgentOrchestrator {
    /// Registered agents, keyed by name.
    agents: HashMap<String, SharedAgent>,
    /// Strategy applied on every call to `run`.
    routing: RoutingStrategy,
    /// When set, receives the last message of every completed agent run.
    shared_memory: Option<SharedConversation>,
    /// When set, receives an `AgentCompleted` event after every agent run.
    event_bus: Option<Arc<dyn EventBus>>,
}
55
56impl MultiAgentOrchestrator {
57 pub fn builder() -> MultiAgentOrchestratorBuilder {
59 MultiAgentOrchestratorBuilder {
60 agents: HashMap::new(),
61 order: Vec::new(),
62 routing: None,
63 shared_memory: None,
64 event_bus: None,
65 }
66 }
67
68 pub async fn run(&mut self, input: &str) -> Result<AgentOutput> {
70 match self.routing.clone() {
71 RoutingStrategy::Sequential(names) => self.run_sequential(input, names).await,
72 RoutingStrategy::Parallel(names) => self.run_parallel(input, names).await,
73 RoutingStrategy::Supervisor {
74 llm,
75 agents,
76 max_turns,
77 } => self.run_supervisor(input, llm, agents, max_turns).await,
78 RoutingStrategy::Graph(graph) => self.run_graph(input, &graph).await,
79 }
80 }
81
82 pub fn add_agent_boxed(&mut self, name: impl Into<String>, agent: Box<dyn AgentTrait>) {
84 self.agents.insert(name.into(), Arc::new(Mutex::new(agent)));
85 }
86
87 async fn run_sequential(&mut self, input: &str, names: Vec<String>) -> Result<AgentOutput> {
88 let mut current_input = input.to_string();
89 let mut final_output = None;
90
91 for name in names {
92 let output = self.run_named_agent(&name, ¤t_input).await?;
93 current_input = output.text.clone();
94 final_output = Some(output);
95 }
96
97 final_output
98 .ok_or_else(|| AgentError::InvalidConfiguration("no agents configured".to_string()))
99 }
100
101 async fn run_parallel(&mut self, input: &str, names: Vec<String>) -> Result<AgentOutput> {
102 let futures = names.into_iter().map(|name| {
103 let agent = self
104 .agents
105 .get(&name)
106 .cloned()
107 .ok_or_else(|| AgentError::AgentNotFound(name.clone()));
108 let event_bus = self.event_bus.clone();
109 let shared_memory = self.shared_memory.clone();
110 let input = input.to_string();
111 async move {
112 let agent = agent?;
113 let mut agent = agent.lock().await;
114 let output = agent.run(&input).await?;
115 if let Some(shared_memory) = shared_memory {
116 if let Some(last_message) = output.messages.last().cloned() {
117 shared_memory.add(&name, last_message).await?;
118 }
119 }
120 if let Some(event_bus) = event_bus {
121 event_bus
122 .publish(OrchestratorEvent::AgentCompleted {
123 agent: name.clone(),
124 output: output.clone(),
125 })
126 .await?;
127 }
128 Ok::<_, AgentError>((name, output))
129 }
130 });
131
132 let results = futures::future::try_join_all(futures).await?;
133 let mut messages = Vec::new();
134 let mut text = String::new();
135 for (name, output) in results {
136 if !text.is_empty() {
137 text.push_str("\n\n");
138 }
139 text.push_str(&format!("[{name}]\n{}", output.text));
140 messages.extend(output.messages);
141 }
142
143 Ok(AgentOutput {
144 text,
145 steps: 1,
146 usage: Default::default(),
147 messages,
148 metadata: HashMap::new(),
149 })
150 }
151
152 async fn run_supervisor(
153 &mut self,
154 input: &str,
155 llm: Arc<dyn LlmProvider>,
156 agents: Vec<String>,
157 max_turns: usize,
158 ) -> Result<AgentOutput> {
159 let agent_lines = agents
160 .iter()
161 .map(|name| format!("- {name}"))
162 .collect::<Vec<_>>()
163 .join("\n");
164 let mut context = format!(
165 "You are a supervisor. Available agents:\n{agent_lines}\n\nReturn JSON {{\"agent\": \"name\"}} for the best agent to handle the task. Task: {input}"
166 );
167
168 for _ in 0..max_turns {
169 let response = llm
170 .complete(CompletionRequest {
171 messages: vec![Message::user(context.clone())],
172 tools: None,
173 model: String::new(),
174 temperature: Some(0.0),
175 max_tokens: Some(256),
176 stream: false,
177 system: None,
178 extra: HashMap::new(),
179 })
180 .await?;
181 let choice: serde_json::Value = serde_json::from_str(&response.message.text_content())?;
182 if let Some(agent_name) = choice.get("agent").and_then(serde_json::Value::as_str) {
183 return self.run_named_agent(agent_name, input).await;
184 }
185 context = format!(
186 "{context}\nPrevious response was invalid JSON: {}",
187 response.message.text_content()
188 );
189 }
190
191 Err(AgentError::MaxStepsReached { steps: max_turns })
192 }
193
194 async fn run_graph(&mut self, input: &str, graph: &AgentGraph) -> Result<AgentOutput> {
195 let mut current = graph.entry()?.to_string();
196 let mut current_input = input.to_string();
197
198 loop {
199 let output = self.run_named_agent(¤t, ¤t_input).await?;
200 if let Some(next) = graph.next(¤t, &output) {
201 current = next;
202 current_input = output.text.clone();
203 continue;
204 }
205 return Ok(output);
206 }
207 }
208
209 async fn run_named_agent(&mut self, name: &str, input: &str) -> Result<AgentOutput> {
210 let agent = self
211 .agents
212 .get(name)
213 .cloned()
214 .ok_or_else(|| AgentError::AgentNotFound(name.to_string()))?;
215
216 let mut agent = agent.lock().await;
217 let output = agent.run(input).await?;
218
219 if let Some(shared_memory) = &self.shared_memory {
220 if let Some(last_message) = output.messages.last().cloned() {
221 shared_memory.add(name, last_message).await?;
222 }
223 }
224 if let Some(event_bus) = &self.event_bus {
225 event_bus
226 .publish(OrchestratorEvent::AgentCompleted {
227 agent: name.to_string(),
228 output: output.clone(),
229 })
230 .await?;
231 }
232
233 Ok(output)
234 }
235}
236
237impl MultiAgentOrchestratorBuilder {
238 pub fn add_agent(mut self, name: impl Into<String>, agent: impl AgentTrait + 'static) -> Self {
240 let name = name.into();
241 self.order.push(name.clone());
242 self.agents
243 .insert(name, Arc::new(Mutex::new(Box::new(agent))));
244 self
245 }
246
247 pub fn add_agent_boxed(mut self, name: impl Into<String>, agent: Box<dyn AgentTrait>) -> Self {
249 let name = name.into();
250 self.order.push(name.clone());
251 self.agents.insert(name, Arc::new(Mutex::new(agent)));
252 self
253 }
254
255 pub fn routing(mut self, routing: RoutingStrategy) -> Self {
257 self.routing = Some(routing);
258 self
259 }
260
261 pub fn shared_memory(mut self, shared_memory: SharedConversation) -> Self {
263 self.shared_memory = Some(shared_memory);
264 self
265 }
266
267 pub fn event_bus(mut self, event_bus: Arc<dyn EventBus>) -> Self {
269 self.event_bus = Some(event_bus);
270 self
271 }
272
273 pub fn build(self) -> Result<MultiAgentOrchestrator> {
275 let routing = self
276 .routing
277 .unwrap_or_else(|| RoutingStrategy::Sequential(self.order.clone()));
278 Ok(MultiAgentOrchestrator {
279 agents: self.agents,
280 routing,
281 shared_memory: self.shared_memory,
282 event_bus: self.event_bus,
283 })
284 }
285}