Skip to main content

agentrs_multi/
orchestrator.rs

1use std::{collections::HashMap, sync::Arc};
2
3use tokio::sync::Mutex;
4
5use agentrs_core::{
6    Agent as AgentTrait, AgentError, AgentOutput, CompletionRequest, LlmProvider, Message, Result,
7};
8
9use crate::{AgentGraph, EventBus, OrchestratorEvent, SharedConversation};
10
11type SharedAgent = Arc<Mutex<Box<dyn AgentTrait>>>;
12
13/// Routing strategy for multi-agent runs.
14#[derive(Clone)]
15pub enum RoutingStrategy {
16    /// Run agents in the given order.
17    Sequential(Vec<String>),
18    /// Run agents concurrently and merge their outputs.
19    Parallel(Vec<String>),
20    /// Use a routing LLM to decide which agent handles the task.
21    Supervisor {
22        /// Provider used for routing decisions.
23        llm: Arc<dyn LlmProvider>,
24        /// Eligible agent names.
25        agents: Vec<String>,
26        /// Maximum supervisor turns.
27        max_turns: usize,
28    },
29    /// Traverse a graph of agents.
30    Graph(AgentGraph),
31}
32
33impl Default for RoutingStrategy {
34    fn default() -> Self {
35        Self::Sequential(Vec::new())
36    }
37}
38
39/// Builder for [`MultiAgentOrchestrator`].
40pub struct MultiAgentOrchestratorBuilder {
41    agents: HashMap<String, SharedAgent>,
42    order: Vec<String>,
43    routing: Option<RoutingStrategy>,
44    shared_memory: Option<SharedConversation>,
45    event_bus: Option<Arc<dyn EventBus>>,
46}
47
48/// Multi-agent orchestrator.
49pub struct MultiAgentOrchestrator {
50    agents: HashMap<String, SharedAgent>,
51    routing: RoutingStrategy,
52    shared_memory: Option<SharedConversation>,
53    event_bus: Option<Arc<dyn EventBus>>,
54}
55
56impl MultiAgentOrchestrator {
57    /// Starts building an orchestrator.
58    pub fn builder() -> MultiAgentOrchestratorBuilder {
59        MultiAgentOrchestratorBuilder {
60            agents: HashMap::new(),
61            order: Vec::new(),
62            routing: None,
63            shared_memory: None,
64            event_bus: None,
65        }
66    }
67
68    /// Runs the configured workflow.
69    pub async fn run(&mut self, input: &str) -> Result<AgentOutput> {
70        match self.routing.clone() {
71            RoutingStrategy::Sequential(names) => self.run_sequential(input, names).await,
72            RoutingStrategy::Parallel(names) => self.run_parallel(input, names).await,
73            RoutingStrategy::Supervisor {
74                llm,
75                agents,
76                max_turns,
77            } => self.run_supervisor(input, llm, agents, max_turns).await,
78            RoutingStrategy::Graph(graph) => self.run_graph(input, &graph).await,
79        }
80    }
81
82    /// Registers or replaces a named agent after construction.
83    pub fn add_agent_boxed(&mut self, name: impl Into<String>, agent: Box<dyn AgentTrait>) {
84        self.agents.insert(name.into(), Arc::new(Mutex::new(agent)));
85    }
86
87    async fn run_sequential(&mut self, input: &str, names: Vec<String>) -> Result<AgentOutput> {
88        let mut current_input = input.to_string();
89        let mut final_output = None;
90
91        for name in names {
92            let output = self.run_named_agent(&name, &current_input).await?;
93            current_input = output.text.clone();
94            final_output = Some(output);
95        }
96
97        final_output
98            .ok_or_else(|| AgentError::InvalidConfiguration("no agents configured".to_string()))
99    }
100
101    async fn run_parallel(&mut self, input: &str, names: Vec<String>) -> Result<AgentOutput> {
102        let futures = names.into_iter().map(|name| {
103            let agent = self
104                .agents
105                .get(&name)
106                .cloned()
107                .ok_or_else(|| AgentError::AgentNotFound(name.clone()));
108            let event_bus = self.event_bus.clone();
109            let shared_memory = self.shared_memory.clone();
110            let input = input.to_string();
111            async move {
112                let agent = agent?;
113                let mut agent = agent.lock().await;
114                let output = agent.run(&input).await?;
115                if let Some(shared_memory) = shared_memory {
116                    if let Some(last_message) = output.messages.last().cloned() {
117                        shared_memory.add(&name, last_message).await?;
118                    }
119                }
120                if let Some(event_bus) = event_bus {
121                    event_bus
122                        .publish(OrchestratorEvent::AgentCompleted {
123                            agent: name.clone(),
124                            output: output.clone(),
125                        })
126                        .await?;
127                }
128                Ok::<_, AgentError>((name, output))
129            }
130        });
131
132        let results = futures::future::try_join_all(futures).await?;
133        let mut messages = Vec::new();
134        let mut text = String::new();
135        for (name, output) in results {
136            if !text.is_empty() {
137                text.push_str("\n\n");
138            }
139            text.push_str(&format!("[{name}]\n{}", output.text));
140            messages.extend(output.messages);
141        }
142
143        Ok(AgentOutput {
144            text,
145            steps: 1,
146            usage: Default::default(),
147            messages,
148            metadata: HashMap::new(),
149        })
150    }
151
152    async fn run_supervisor(
153        &mut self,
154        input: &str,
155        llm: Arc<dyn LlmProvider>,
156        agents: Vec<String>,
157        max_turns: usize,
158    ) -> Result<AgentOutput> {
159        let agent_lines = agents
160            .iter()
161            .map(|name| format!("- {name}"))
162            .collect::<Vec<_>>()
163            .join("\n");
164        let mut context = format!(
165            "You are a supervisor. Available agents:\n{agent_lines}\n\nReturn JSON {{\"agent\": \"name\"}} for the best agent to handle the task. Task: {input}"
166        );
167
168        for _ in 0..max_turns {
169            let response = llm
170                .complete(CompletionRequest {
171                    messages: vec![Message::user(context.clone())],
172                    tools: None,
173                    model: String::new(),
174                    temperature: Some(0.0),
175                    max_tokens: Some(256),
176                    stream: false,
177                    system: None,
178                    extra: HashMap::new(),
179                })
180                .await?;
181            let choice: serde_json::Value = serde_json::from_str(&response.message.text_content())?;
182            if let Some(agent_name) = choice.get("agent").and_then(serde_json::Value::as_str) {
183                return self.run_named_agent(agent_name, input).await;
184            }
185            context = format!(
186                "{context}\nPrevious response was invalid JSON: {}",
187                response.message.text_content()
188            );
189        }
190
191        Err(AgentError::MaxStepsReached { steps: max_turns })
192    }
193
194    async fn run_graph(&mut self, input: &str, graph: &AgentGraph) -> Result<AgentOutput> {
195        let mut current = graph.entry()?.to_string();
196        let mut current_input = input.to_string();
197
198        loop {
199            let output = self.run_named_agent(&current, &current_input).await?;
200            if let Some(next) = graph.next(&current, &output) {
201                current = next;
202                current_input = output.text.clone();
203                continue;
204            }
205            return Ok(output);
206        }
207    }
208
209    async fn run_named_agent(&mut self, name: &str, input: &str) -> Result<AgentOutput> {
210        let agent = self
211            .agents
212            .get(name)
213            .cloned()
214            .ok_or_else(|| AgentError::AgentNotFound(name.to_string()))?;
215
216        let mut agent = agent.lock().await;
217        let output = agent.run(input).await?;
218
219        if let Some(shared_memory) = &self.shared_memory {
220            if let Some(last_message) = output.messages.last().cloned() {
221                shared_memory.add(name, last_message).await?;
222            }
223        }
224        if let Some(event_bus) = &self.event_bus {
225            event_bus
226                .publish(OrchestratorEvent::AgentCompleted {
227                    agent: name.to_string(),
228                    output: output.clone(),
229                })
230                .await?;
231        }
232
233        Ok(output)
234    }
235}
236
237impl MultiAgentOrchestratorBuilder {
238    /// Registers an agent by name.
239    pub fn add_agent(mut self, name: impl Into<String>, agent: impl AgentTrait + 'static) -> Self {
240        let name = name.into();
241        self.order.push(name.clone());
242        self.agents
243            .insert(name, Arc::new(Mutex::new(Box::new(agent))));
244        self
245    }
246
247    /// Registers a boxed agent by name.
248    pub fn add_agent_boxed(mut self, name: impl Into<String>, agent: Box<dyn AgentTrait>) -> Self {
249        let name = name.into();
250        self.order.push(name.clone());
251        self.agents.insert(name, Arc::new(Mutex::new(agent)));
252        self
253    }
254
255    /// Sets the routing strategy.
256    pub fn routing(mut self, routing: RoutingStrategy) -> Self {
257        self.routing = Some(routing);
258        self
259    }
260
261    /// Enables shared conversation memory.
262    pub fn shared_memory(mut self, shared_memory: SharedConversation) -> Self {
263        self.shared_memory = Some(shared_memory);
264        self
265    }
266
267    /// Enables orchestration events.
268    pub fn event_bus(mut self, event_bus: Arc<dyn EventBus>) -> Self {
269        self.event_bus = Some(event_bus);
270        self
271    }
272
273    /// Builds the orchestrator.
274    pub fn build(self) -> Result<MultiAgentOrchestrator> {
275        let routing = self
276            .routing
277            .unwrap_or_else(|| RoutingStrategy::Sequential(self.order.clone()));
278        Ok(MultiAgentOrchestrator {
279            agents: self.agents,
280            routing,
281            shared_memory: self.shared_memory,
282            event_bus: self.event_bus,
283        })
284    }
285}