1#![allow(dead_code)]
12
13use crate::agency::agent::Agent;
14use crate::agency::error::{AgencyError, AgencyResult};
15use crate::agency::executor::{ExecutionContext, ExecutionResult, Executor};
16use crate::agency::models::{AgencyEvent, EventType, TokenUsage};
17use crate::agency::session::Session;
18use chrono::Utc;
19use serde::{Deserialize, Serialize};
20use std::sync::Arc;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum OrchestrationType {
26 Sequential,
28 Parallel,
30 Loop,
32 Hierarchical,
34}
35
36#[derive(Debug, Clone)]
38pub struct Pipeline {
39 pub name: String,
41 pub orchestration: OrchestrationType,
43 pub agents: Vec<Arc<Agent>>,
45 pub max_iterations: u32,
47}
48
49impl Pipeline {
50 pub fn sequential(name: impl Into<String>, agents: Vec<Agent>) -> Self {
52 Self {
53 name: name.into(),
54 orchestration: OrchestrationType::Sequential,
55 agents: agents.into_iter().map(Arc::new).collect(),
56 max_iterations: 1,
57 }
58 }
59
60 pub fn parallel(name: impl Into<String>, agents: Vec<Agent>) -> Self {
62 Self {
63 name: name.into(),
64 orchestration: OrchestrationType::Parallel,
65 agents: agents.into_iter().map(Arc::new).collect(),
66 max_iterations: 1,
67 }
68 }
69
70 pub fn loop_agent(name: impl Into<String>, agent: Agent, max_iterations: u32) -> Self {
72 Self {
73 name: name.into(),
74 orchestration: OrchestrationType::Loop,
75 agents: vec![Arc::new(agent)],
76 max_iterations,
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct Swarm {
84 pub name: String,
86 pub description: String,
88 pub coordinator: Arc<Agent>,
90 pub workers: Vec<Arc<Agent>>,
92 pub goal: Option<String>,
94}
95
96impl Swarm {
97 pub fn new(
99 name: impl Into<String>,
100 description: impl Into<String>,
101 coordinator: Agent,
102 workers: Vec<Agent>,
103 ) -> Self {
104 Self {
105 name: name.into(),
106 description: description.into(),
107 coordinator: Arc::new(coordinator),
108 workers: workers.into_iter().map(Arc::new).collect(),
109 goal: None,
110 }
111 }
112
113 pub fn with_goal(mut self, goal: impl Into<String>) -> Self {
115 self.goal = Some(goal.into());
116 self
117 }
118}
119
120pub struct Orchestrator {
122 executor: Arc<Executor>,
123}
124
125impl Orchestrator {
126 pub fn new(executor: Arc<Executor>) -> Self {
128 Self { executor }
129 }
130
131 pub async fn run_pipeline(
133 &self,
134 pipeline: &Pipeline,
135 input: &str,
136 ctx: &mut ExecutionContext,
137 ) -> AgencyResult<OrchestratorResult> {
138 match pipeline.orchestration {
139 OrchestrationType::Sequential => self.run_sequential(pipeline, input, ctx).await,
140 OrchestrationType::Parallel => self.run_parallel(pipeline, input, ctx).await,
141 OrchestrationType::Loop => self.run_loop(pipeline, input, ctx).await,
142 OrchestrationType::Hierarchical => Err(AgencyError::OrchestrationError(
143 "Use run_swarm for hierarchical orchestration".to_string(),
144 )),
145 }
146 }
147
148 async fn run_sequential(
150 &self,
151 pipeline: &Pipeline,
152 input: &str,
153 ctx: &mut ExecutionContext,
154 ) -> AgencyResult<OrchestratorResult> {
155 let start_time = std::time::Instant::now();
156 let mut results = Vec::new();
157 let mut events = Vec::new();
158 let mut token_usage = TokenUsage::default();
159 let mut current_input = input.to_string();
160
161 for agent_arc in &pipeline.agents {
162 let agent = agent_arc.as_ref();
163 let mut session = Session::new(agent.name(), ctx.user_id.clone());
164
165 let result = self
166 .executor
167 .execute(agent, &mut session, ¤t_input, ctx)
168 .await?;
169
170 current_input = result.response.clone();
172 token_usage.add(&result.token_usage);
173 events.extend(result.events.clone());
174 results.push(result);
175 }
176
177 let final_response = results
178 .last()
179 .map(|r| r.response.clone())
180 .unwrap_or_default();
181
182 Ok(OrchestratorResult {
183 response: final_response,
184 agent_results: results,
185 events,
186 token_usage,
187 duration_ms: start_time.elapsed().as_millis() as u64,
188 iterations: 1,
189 })
190 }
191
192 async fn run_parallel(
194 &self,
195 pipeline: &Pipeline,
196 input: &str,
197 ctx: &mut ExecutionContext,
198 ) -> AgencyResult<OrchestratorResult> {
199 let start_time = std::time::Instant::now();
200 let mut handles = Vec::new();
201
202 for agent_arc in &pipeline.agents {
203 let agent = agent_arc.clone();
204 let executor = self.executor.clone();
205 let input = input.to_string();
206 let user_id = ctx.user_id.clone();
207
208 handles.push(tokio::spawn(async move {
209 let mut session = Session::new(agent.name(), user_id.clone());
210 let mut ctx = ExecutionContext::new(&session);
211 ctx.user_id = user_id;
212
213 executor
214 .execute(agent.as_ref(), &mut session, &input, &mut ctx)
215 .await
216 }));
217 }
218
219 let mut results = Vec::new();
220 let mut events = Vec::new();
221 let mut token_usage = TokenUsage::default();
222 let mut responses = Vec::new();
223
224 for handle in handles {
225 match handle.await {
226 Ok(Ok(result)) => {
227 responses.push(result.response.clone());
228 token_usage.add(&result.token_usage);
229 events.extend(result.events.clone());
230 results.push(result);
231 }
232 Ok(Err(e)) => {
233 return Err(e);
234 }
235 Err(e) => {
236 return Err(AgencyError::ExecutionFailed(e.to_string()));
237 }
238 }
239 }
240
241 let final_response = responses.join("\n\n---\n\n");
243
244 Ok(OrchestratorResult {
245 response: final_response,
246 agent_results: results,
247 events,
248 token_usage,
249 duration_ms: start_time.elapsed().as_millis() as u64,
250 iterations: 1,
251 })
252 }
253
254 async fn run_loop(
256 &self,
257 pipeline: &Pipeline,
258 input: &str,
259 ctx: &mut ExecutionContext,
260 ) -> AgencyResult<OrchestratorResult> {
261 let start_time = std::time::Instant::now();
262 let mut results = Vec::new();
263 let mut events = Vec::new();
264 let mut token_usage = TokenUsage::default();
265 let mut current_input = input.to_string();
266 let mut iterations = 0;
267
268 let agent_arc = pipeline.agents.first().ok_or_else(|| {
269 AgencyError::OrchestrationError("Loop pipeline requires at least one agent".to_string())
270 })?;
271
272 loop {
273 iterations += 1;
274 if iterations > pipeline.max_iterations {
275 break;
276 }
277
278 let agent = agent_arc.as_ref();
279 let mut session = Session::new(agent.name(), ctx.user_id.clone());
280
281 let result = self
282 .executor
283 .execute(agent, &mut session, ¤t_input, ctx)
284 .await?;
285
286 token_usage.add(&result.token_usage);
287 events.extend(result.events.clone());
288 results.push(result.clone());
289
290 if result.response.contains("DONE")
293 || result.response.contains("COMPLETE")
294 || result.response.contains("FINISHED")
295 {
296 break;
297 }
298
299 current_input = result.response;
301 }
302
303 let final_response = results
304 .last()
305 .map(|r| r.response.clone())
306 .unwrap_or_default();
307
308 Ok(OrchestratorResult {
309 response: final_response,
310 agent_results: results,
311 events,
312 token_usage,
313 duration_ms: start_time.elapsed().as_millis() as u64,
314 iterations,
315 })
316 }
317
318 pub async fn run_swarm(
320 &self,
321 swarm: &Swarm,
322 input: &str,
323 ctx: &mut ExecutionContext,
324 ) -> AgencyResult<OrchestratorResult> {
325 let start_time = std::time::Instant::now();
326 let mut results = Vec::new();
327 let mut events = Vec::new();
328 let mut token_usage = TokenUsage::default();
329
330 let coordinator = swarm.coordinator.as_ref();
332 let mut coord_session = Session::new(coordinator.name(), ctx.user_id.clone());
333
334 let worker_info: Vec<_> = swarm
336 .workers
337 .iter()
338 .map(|w| format!("- {}: {}", w.name(), w.description()))
339 .collect();
340
341 let coordinator_input = format!(
342 "Task: {}\n\nAvailable workers:\n{}\n\nAnalyze the task and delegate to appropriate workers.",
343 input,
344 worker_info.join("\n")
345 );
346
347 let coord_result = self
348 .executor
349 .execute(coordinator, &mut coord_session, &coordinator_input, ctx)
350 .await?;
351
352 token_usage.add(&coord_result.token_usage);
353 events.extend(coord_result.events.clone());
354 results.push(coord_result.clone());
355
356 let handoff_event = AgencyEvent {
358 event_type: EventType::Handoff,
359 agent_name: coordinator.name().to_string(),
360 data: serde_json::json!({
361 "from": coordinator.name(),
362 "task": input
363 }),
364 timestamp: Utc::now(),
365 session_id: Some(coord_session.id.clone()),
366 };
367 events.push(handoff_event.clone());
368 ctx.emit(handoff_event).await;
369
370 for worker_arc in &swarm.workers {
373 let worker = worker_arc.as_ref();
374 let mut worker_session = Session::new(worker.name(), ctx.user_id.clone());
375
376 let worker_result = self
377 .executor
378 .execute(worker, &mut worker_session, input, ctx)
379 .await?;
380
381 token_usage.add(&worker_result.token_usage);
382 events.extend(worker_result.events.clone());
383 results.push(worker_result);
384 }
385
386 let worker_results: Vec<_> = results
388 .iter()
389 .skip(1) .map(|r| format!("Result: {}", r.response))
391 .collect();
392
393 let synthesis_input = format!(
394 "Original task: {}\n\nWorker results:\n{}\n\nSynthesize these results into a final response.",
395 input,
396 worker_results.join("\n\n")
397 );
398
399 let final_result = self
400 .executor
401 .execute(coordinator, &mut coord_session, &synthesis_input, ctx)
402 .await?;
403
404 token_usage.add(&final_result.token_usage);
405 events.extend(final_result.events.clone());
406 results.push(final_result.clone());
407
408 Ok(OrchestratorResult {
409 response: final_result.response,
410 agent_results: results,
411 events,
412 token_usage,
413 duration_ms: start_time.elapsed().as_millis() as u64,
414 iterations: 1,
415 })
416 }
417}
418
419#[derive(Debug, Clone, Serialize, Deserialize)]
421pub struct OrchestratorResult {
422 pub response: String,
424 pub agent_results: Vec<ExecutionResult>,
426 pub events: Vec<AgencyEvent>,
428 pub token_usage: TokenUsage,
430 pub duration_ms: u64,
432 pub iterations: u32,
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agency::agent::AgentBuilder;
    use crate::agency::tools::ToolRegistry;

    /// Builds a minimal agent named `name` for use in tests.
    fn create_test_agent(name: &str) -> Agent {
        AgentBuilder::new(name)
            .description(format!("{} agent", name))
            .instruction("You are a helpful assistant.")
            .model("gemini-2.5-flash")
            .build()
    }

    /// Builds an orchestrator backed by an executor with an empty tool registry.
    fn create_test_orchestrator() -> Orchestrator {
        let tool_registry = Arc::new(ToolRegistry::new());
        let executor = Arc::new(Executor::new(tool_registry));
        Orchestrator::new(executor)
    }

    #[test]
    fn single_pass_constructors_set_fields() {
        let seq = Pipeline::sequential(
            "seq",
            vec![create_test_agent("a"), create_test_agent("b")],
        );
        assert_eq!(seq.name, "seq");
        assert_eq!(seq.orchestration, OrchestrationType::Sequential);
        assert_eq!(seq.agents.len(), 2);
        assert_eq!(seq.max_iterations, 1);

        let par = Pipeline::parallel("par", vec![create_test_agent("c")]);
        assert_eq!(par.name, "par");
        assert_eq!(par.orchestration, OrchestrationType::Parallel);
        assert_eq!(par.agents.len(), 1);
        assert_eq!(par.max_iterations, 1);
    }

    #[test]
    fn loop_constructor_wraps_single_agent() {
        let pipeline = Pipeline::loop_agent("refine", create_test_agent("looper"), 5);
        assert_eq!(pipeline.orchestration, OrchestrationType::Loop);
        assert_eq!(pipeline.agents.len(), 1);
        assert_eq!(pipeline.max_iterations, 5);
    }

    #[test]
    fn swarm_with_goal_records_goal() {
        let swarm = Swarm::new(
            "team",
            "a test swarm",
            create_test_agent("lead"),
            vec![create_test_agent("w1"), create_test_agent("w2")],
        );
        assert!(swarm.goal.is_none());
        assert_eq!(swarm.workers.len(), 2);

        let swarm = swarm.with_goal("ship it");
        assert_eq!(swarm.goal.as_deref(), Some("ship it"));
    }

    #[tokio::test]
    async fn run_pipeline_rejects_hierarchical() {
        let orchestrator = create_test_orchestrator();
        let pipeline = Pipeline {
            name: "hierarchical".to_string(),
            orchestration: OrchestrationType::Hierarchical,
            agents: Vec::new(),
            max_iterations: 1,
        };

        let session = Session::new("test", None);
        let mut ctx = ExecutionContext::new(&session);

        let err = orchestrator
            .run_pipeline(&pipeline, "anything", &mut ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, AgencyError::OrchestrationError(_)));
    }

    #[tokio::test]
    async fn loop_pipeline_without_agents_is_an_error() {
        let orchestrator = create_test_orchestrator();
        let pipeline = Pipeline {
            name: "empty_loop".to_string(),
            orchestration: OrchestrationType::Loop,
            agents: Vec::new(),
            max_iterations: 3,
        };

        let session = Session::new("test", None);
        let mut ctx = ExecutionContext::new(&session);

        let err = orchestrator
            .run_pipeline(&pipeline, "anything", &mut ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, AgencyError::OrchestrationError(_)));
    }

    #[tokio::test]
    #[ignore = "Integration test - requires API credentials"]
    async fn test_sequential_pipeline() {
        let orchestrator = create_test_orchestrator();

        let agents = vec![create_test_agent("researcher"), create_test_agent("writer")];
        let pipeline = Pipeline::sequential("research_pipeline", agents);

        let session = Session::new("test", None);
        let mut ctx = ExecutionContext::new(&session);

        let result = orchestrator
            .run_pipeline(&pipeline, "Tell me about Rust", &mut ctx)
            .await
            .unwrap();

        assert!(!result.response.is_empty());
        assert_eq!(result.agent_results.len(), 2);
    }

    #[tokio::test]
    #[ignore = "Integration test - requires API credentials"]
    async fn test_parallel_pipeline() {
        let orchestrator = create_test_orchestrator();

        let agents = vec![create_test_agent("analyst1"), create_test_agent("analyst2")];
        let pipeline = Pipeline::parallel("analysis_pipeline", agents);

        let session = Session::new("test", None);
        let mut ctx = ExecutionContext::new(&session);

        let result = orchestrator
            .run_pipeline(&pipeline, "Analyze this data", &mut ctx)
            .await
            .unwrap();

        assert!(!result.response.is_empty());
        assert_eq!(result.agent_results.len(), 2);
    }
}