1use crate::agent::{Agent, AgentError};
7use crate::agent_loop::{LoopConfig, run_loop};
8use crate::context::AgentContext;
9use crate::registry::ToolRegistry;
10use crate::types::Message;
11use std::collections::HashMap;
12use std::fmt;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tokio::sync::{Mutex, mpsc, oneshot};
16use tokio_util::sync::CancellationToken;
17
/// Identifier of a spawned agent, e.g. `agent-7`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct AgentId(pub String);

impl Default for AgentId {
    /// Allocates a fresh id of the form `agent-<n>` from a process-wide counter.
    fn default() -> Self {
        let n = next_id();
        AgentId(format!("agent-{}", n))
    }
}

impl AgentId {
    /// Creates a fresh counter-based id; identical to [`AgentId::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the id as a string slice.
    ///
    /// NOTE(review): despite the name, this is the full id string; nothing is shortened.
    pub fn short(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for AgentId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

/// Returns the next value of a process-wide counter that starts at 1.
fn next_id() -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Only uniqueness matters here, so no cross-thread ordering is required.
    static NEXT: AtomicU64 = AtomicU64::new(1);
    NEXT.fetch_add(1, Ordering::Relaxed)
}
49
/// The kind of agent being spawned; determines the default prompt wording and step budget.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentRole {
    Explorer,
    Worker,
    Reviewer,
    /// A caller-named role; the wrapped string is used verbatim as its name.
    Custom(String),
}

impl AgentRole {
    /// Lower-case role name (`"explorer"`, `"worker"`, `"reviewer"`), or the
    /// wrapped string for [`AgentRole::Custom`].
    pub fn name(&self) -> &str {
        let label: &str = match self {
            AgentRole::Explorer => "explorer",
            AgentRole::Worker => "worker",
            AgentRole::Reviewer => "reviewer",
            AgentRole::Custom(custom) => custom.as_str(),
        };
        label
    }
}

impl fmt::Display for AgentRole {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
79
/// Lifecycle state of a spawned agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentStatus {
    Running,
    Completed,
    /// Terminated with an error; carries the error text.
    Failed(String),
    Cancelled,
}

impl fmt::Display for AgentStatus {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Running => "running",
            Self::Completed => "completed",
            Self::Cancelled => "cancelled",
            // The only variant with a payload is rendered with its reason.
            Self::Failed(reason) => return write!(f, "failed: {}", reason),
        };
        f.write_str(label)
    }
}
99
/// Final outcome of one agent, delivered through its one-shot result channel.
#[derive(Debug, Clone)]
pub struct SwarmResult {
    /// Id of the agent that produced this result.
    pub id: AgentId,
    /// Role the agent was spawned with.
    pub role: AgentRole,
    /// Terminal status: `Completed`, `Failed` or `Cancelled`.
    pub status: AgentStatus,
    /// Last assistant message on success (or `"Completed"` if there was none),
    /// the error text on failure, or `"Cancelled"`.
    pub summary: String,
    /// Step count reported by the agent loop; `0` for failed or cancelled agents.
    pub steps: usize,
    /// Every event emitted by the agent loop, `Debug`-formatted, in emission order.
    pub events: Vec<String>,
}
113
114pub struct SpawnConfig {
116 pub role: AgentRole,
118 pub system_prompt: Option<String>,
120 pub tool_names: Option<Vec<String>>,
122 pub cwd: Option<PathBuf>,
124 pub task: String,
126 pub max_steps: usize,
128 pub writable_roots: Option<Vec<PathBuf>>,
130}
131
132impl SpawnConfig {
133 pub fn explorer(task: impl Into<String>) -> Self {
134 Self {
135 role: AgentRole::Explorer,
136 system_prompt: None,
137 tool_names: None,
138 cwd: None,
139 task: task.into(),
140 max_steps: 10,
141 writable_roots: None,
142 }
143 }
144
145 pub fn worker(task: impl Into<String>) -> Self {
146 Self {
147 role: AgentRole::Worker,
148 system_prompt: None,
149 tool_names: None,
150 cwd: None,
151 task: task.into(),
152 max_steps: 30,
153 writable_roots: None,
154 }
155 }
156
157 pub fn reviewer(task: impl Into<String>) -> Self {
158 Self {
159 role: AgentRole::Reviewer,
160 system_prompt: None,
161 tool_names: None,
162 cwd: None,
163 task: task.into(),
164 max_steps: 15,
165 writable_roots: None,
166 }
167 }
168}
169
/// Errors returned by `SwarmManager` operations.
#[derive(Debug, thiserror::Error)]
pub enum SwarmError {
    /// `spawn` refused: the active-agent limit (the payload) has been reached.
    #[error("Max agents reached ({0})")]
    MaxAgents(usize),
    /// `spawn` refused: the manager is already at the depth limit (the payload).
    #[error("Max depth reached ({0})")]
    MaxDepth(usize),
    /// No agent with this id is tracked by the manager.
    #[error("Agent not found: {0}")]
    NotFound(AgentId),
    /// The agent's result receiver was already taken.
    ///
    /// NOTE(review): despite the name, the agent may still be running — this is
    /// returned whenever the receiver has already been claimed by a caller.
    #[error("Agent already completed: {0}")]
    AlreadyCompleted(AgentId),
    /// Wraps an [`AgentError`].
    #[error("Agent error: {0}")]
    Agent(#[from] AgentError),
    /// The result channel closed before a result was delivered.
    #[error("Channel error")]
    Channel,
}
186
/// Manager-side bookkeeping for one spawned agent.
struct AgentHandle {
    /// Same id the handle is keyed by in `SwarmManager::agents`.
    id: AgentId,
    /// Role the agent was spawned with.
    role: AgentRole,
    /// Cancelling this token makes the agent task abandon its loop.
    cancel: CancellationToken,
    /// Shared with the agent task, which overwrites it with the final status on exit.
    status: Arc<Mutex<AgentStatus>>,
    /// Receiver for the final result; `None` once it has been taken by a caller.
    result_rx: Option<oneshot::Receiver<SwarmResult>>,
}
195
/// Message published on the manager's notification channel after an agent finishes.
#[derive(Debug, Clone)]
pub struct AgentNotification {
    /// Id of the agent that finished.
    pub id: AgentId,
    /// Role the agent was spawned with.
    pub role: AgentRole,
    /// Final status: `Completed`, `Failed` or `Cancelled`.
    pub status: AgentStatus,
    /// Same text as the corresponding `SwarmResult::summary`.
    pub summary: String,
}
204
/// Spawns agents as Tokio tasks and tracks their status and results.
pub struct SwarmManager {
    /// Tracked agents by id. Entries are removed by `cleanup`, or after
    /// `wait` / `wait_all` / `wait_with_timeout` successfully collect a result.
    agents: HashMap<AgentId, AgentHandle>,
    /// Cloned into every agent task to publish completion notifications.
    notification_tx: mpsc::Sender<AgentNotification>,
    /// Receiving end of the notification channel, drained by
    /// `try_recv_notification` / `recv_notification`.
    notification_rx: Arc<Mutex<mpsc::Receiver<AgentNotification>>>,
    /// Limit compared against `active_count()` in `spawn`.
    max_agents: usize,
    /// `spawn` fails once `current_depth` reaches this value.
    max_depth: usize,
    /// Depth of this manager as set by `with_depth` (0 by default).
    current_depth: usize,
}
215
216impl SwarmManager {
217 pub fn new() -> Self {
218 let (tx, rx) = mpsc::channel(64);
219 Self {
220 agents: HashMap::new(),
221 notification_tx: tx,
222 notification_rx: Arc::new(Mutex::new(rx)),
223 max_agents: 8,
224 max_depth: 3,
225 current_depth: 0,
226 }
227 }
228
229 pub fn with_limits(mut self, max_agents: usize, max_depth: usize) -> Self {
230 self.max_agents = max_agents;
231 self.max_depth = max_depth;
232 self
233 }
234
235 pub fn with_depth(mut self, depth: usize) -> Self {
236 self.current_depth = depth;
237 self
238 }
239
240 pub fn spawn(
245 &mut self,
246 config: SpawnConfig,
247 agent: Box<dyn Agent>,
248 tools: ToolRegistry,
249 parent_ctx: &AgentContext,
250 ) -> Result<AgentId, SwarmError> {
251 if self.active_count() >= self.max_agents {
252 return Err(SwarmError::MaxAgents(self.max_agents));
253 }
254 if self.current_depth >= self.max_depth {
255 return Err(SwarmError::MaxDepth(self.max_depth));
256 }
257
258 let id = AgentId::new();
259 let cancel = CancellationToken::new();
260 let status = Arc::new(Mutex::new(AgentStatus::Running));
261 let (result_tx, result_rx) = oneshot::channel();
262
263 let mut ctx = AgentContext::new();
265 ctx.cwd = config.cwd.unwrap_or_else(|| parent_ctx.cwd.clone());
266 ctx.writable_roots = config
267 .writable_roots
268 .unwrap_or_else(|| parent_ctx.writable_roots.clone());
269
270 let system_prompt = config.system_prompt.unwrap_or_else(|| {
272 format!(
273 "You are a {} agent. Complete the assigned task efficiently.",
274 config.role.name()
275 )
276 });
277 let mut messages = vec![Message::system(&system_prompt), Message::user(&config.task)];
278
279 let loop_config = LoopConfig {
280 max_steps: config.max_steps,
281 ..Default::default()
282 };
283
284 let agent_id = id.clone();
285 let agent_role = config.role.clone();
286 let cancel_token = cancel.clone();
287 let status_clone = Arc::clone(&status);
288 let notify_tx = self.notification_tx.clone();
289
290 tokio::spawn(async move {
292 let mut events: Vec<String> = Vec::new();
293
294 let loop_result = tokio::select! {
295 result = run_loop(
296 agent.as_ref(),
297 &tools,
298 &mut ctx,
299 &mut messages,
300 &loop_config,
301 |event| {
302 events.push(format!("{:?}", event));
303 },
304 ) => result,
305 _ = cancel_token.cancelled() => {
306 Err(AgentError::Cancelled)
307 }
308 };
309
310 let (final_status, summary, steps) = match loop_result {
311 Ok(steps) => {
312 let summary = messages
313 .iter()
314 .rev()
315 .find(|m| m.role == crate::types::Role::Assistant)
316 .map(|m| m.content.clone())
317 .unwrap_or_else(|| "Completed".to_string());
318 (AgentStatus::Completed, summary, steps)
319 }
320 Err(AgentError::Cancelled) => (AgentStatus::Cancelled, "Cancelled".to_string(), 0),
321 Err(e) => (AgentStatus::Failed(e.to_string()), e.to_string(), 0),
322 };
323
324 *status_clone.lock().await = final_status.clone();
326
327 let result = SwarmResult {
328 id: agent_id.clone(),
329 role: agent_role.clone(),
330 status: final_status.clone(),
331 summary: summary.clone(),
332 steps,
333 events,
334 };
335
336 let _ = result_tx.send(result);
338
339 let _ = notify_tx
341 .send(AgentNotification {
342 id: agent_id,
343 role: agent_role,
344 status: final_status,
345 summary,
346 })
347 .await;
348 });
349
350 self.agents.insert(
351 id.clone(),
352 AgentHandle {
353 id: id.clone(),
354 role: config.role,
355 cancel,
356 status,
357 result_rx: Some(result_rx),
358 },
359 );
360
361 Ok(id)
362 }
363
364 pub async fn status(&self, id: &AgentId) -> Option<AgentStatus> {
366 if let Some(handle) = self.agents.get(id) {
367 Some(handle.status.lock().await.clone())
368 } else {
369 None
370 }
371 }
372
373 pub async fn status_all(&self) -> Vec<(AgentId, AgentRole, AgentStatus)> {
375 let mut result = Vec::new();
376 for handle in self.agents.values() {
377 let status = handle.status.lock().await.clone();
378 result.push((handle.id.clone(), handle.role.clone(), status));
379 }
380 result
381 }
382
383 pub fn take_receiver(
386 &mut self,
387 id: &AgentId,
388 ) -> Result<oneshot::Receiver<SwarmResult>, SwarmError> {
389 let handle = self
390 .agents
391 .get_mut(id)
392 .ok_or_else(|| SwarmError::NotFound(id.clone()))?;
393
394 handle
395 .result_rx
396 .take()
397 .ok_or_else(|| SwarmError::AlreadyCompleted(id.clone()))
398 }
399
400 pub fn take_all_receivers(&mut self) -> Vec<(AgentId, oneshot::Receiver<SwarmResult>)> {
402 let mut receivers = Vec::new();
403 for (id, handle) in &mut self.agents {
404 if let Some(rx) = handle.result_rx.take() {
405 receivers.push((id.clone(), rx));
406 }
407 }
408 receivers
409 }
410
411 pub async fn wait(&mut self, id: &AgentId) -> Result<SwarmResult, SwarmError> {
414 let rx = self.take_receiver(id)?;
415 let result = rx.await.map_err(|_| SwarmError::Channel)?;
416 self.agents.remove(id); Ok(result)
418 }
419
420 pub async fn wait_all(&mut self) -> Vec<SwarmResult> {
423 let receivers = self.take_all_receivers();
424 let mut results = Vec::new();
425 for (id, rx) in receivers {
426 if let Ok(result) = rx.await {
427 results.push(result);
428 self.agents.remove(&id);
429 }
430 }
431 results
432 }
433
434 pub fn cancel(&self, id: &AgentId) -> Result<(), SwarmError> {
436 let handle = self
437 .agents
438 .get(id)
439 .ok_or_else(|| SwarmError::NotFound(id.clone()))?;
440 handle.cancel.cancel();
441 Ok(())
442 }
443
444 pub fn cancel_all(&self) {
446 for handle in self.agents.values() {
447 handle.cancel.cancel();
448 }
449 }
450
451 pub async fn try_recv_notification(&self) -> Option<AgentNotification> {
453 let mut rx = self.notification_rx.lock().await;
454 rx.try_recv().ok()
455 }
456
457 pub async fn recv_notification(
459 &self,
460 timeout: std::time::Duration,
461 ) -> Option<AgentNotification> {
462 let mut rx = self.notification_rx.lock().await;
463 tokio::time::timeout(timeout, rx.recv())
464 .await
465 .ok()
466 .flatten()
467 }
468
469 pub fn cleanup(&mut self, id: &AgentId) {
471 self.agents.remove(id);
472 }
473
474 pub fn agent_count(&self) -> usize {
476 self.agents.len()
477 }
478
479 pub fn active_count(&self) -> usize {
481 self.agents
482 .values()
483 .filter(|h| h.result_rx.is_some())
484 .count()
485 }
486
487 pub fn all_agent_ids(&self) -> Vec<AgentId> {
489 self.agents.keys().cloned().collect()
490 }
491
492 pub async fn status_all_formatted(&self) -> String {
494 let statuses = self.status_all().await;
495 if statuses.is_empty() {
496 return "No agents.".to_string();
497 }
498 statuses
499 .iter()
500 .map(|(id, role, status)| format!("[{}] {} — {}", id, role, status))
501 .collect::<Vec<_>>()
502 .join("\n")
503 }
504
505 pub async fn wait_with_timeout(
507 &mut self,
508 ids: &[AgentId],
509 timeout: std::time::Duration,
510 ) -> Vec<(AgentId, String)> {
511 let mut results = Vec::new();
512 for id in ids {
513 let rx = match self.take_receiver(id) {
514 Ok(rx) => rx,
515 Err(e) => {
516 results.push((id.clone(), format!("Error: {}", e)));
517 continue;
518 }
519 };
520 match tokio::time::timeout(timeout, rx).await {
521 Ok(Ok(result)) => {
522 let summary = format!(
523 "{} ({}, {} steps): {}",
524 result.status,
525 result.role,
526 result.steps,
527 {
528 use crate::str_ext::StrExt;
529 if result.summary.len() > 500 {
530 format!("{}...", result.summary.trunc(500))
531 } else {
532 result.summary.clone()
533 }
534 }
535 );
536 self.agents.remove(id);
537 results.push((id.clone(), summary));
538 }
539 Ok(Err(_)) => {
540 results.push((id.clone(), "Channel closed".into()));
541 }
542 Err(_) => {
543 results.push((id.clone(), format!("Timeout after {}s", timeout.as_secs())));
544 }
545 }
546 }
547 results
548 }
549
550 pub async fn status_summary(&self) -> String {
552 let mut lines = Vec::new();
553 for handle in self.agents.values() {
554 let status = handle.status.lock().await;
555 lines.push(format!(" {} ({}) — {}", handle.id, handle.role, *status));
556 }
557 if lines.is_empty() {
558 " (none)".to_string()
559 } else {
560 lines.join("\n")
561 }
562 }
563}
564
565impl Default for SwarmManager {
566 fn default() -> Self {
567 Self::new()
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentError, Decision};
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::types::{Message, ToolCall};
    use serde_json::Value;
    use std::time::Duration;

    /// Agent that reports the task as completed on its first decision.
    struct SimpleAgent {}

    #[async_trait::async_trait]
    impl Agent for SimpleAgent {
        async fn decide(
            &self,
            _messages: &[Message],
            _tools: &ToolRegistry,
        ) -> Result<Decision, AgentError> {
            Ok(Decision {
                situation: "Task done.".into(),
                task: vec![],
                tool_calls: vec![],
                completed: true,
            })
        }
    }

    /// Agent that calls `echo` until it has seen `steps` tool messages, then completes.
    struct StepAgent {
        steps: usize,
    }

    #[async_trait::async_trait]
    impl Agent for StepAgent {
        async fn decide(
            &self,
            msgs: &[Message],
            _tools: &ToolRegistry,
        ) -> Result<Decision, AgentError> {
            // Progress is read from the transcript; assumes the loop appends one
            // Tool-role message per executed tool call.
            let tool_msgs = msgs
                .iter()
                .filter(|m| m.role == crate::types::Role::Tool)
                .count();
            if tool_msgs >= self.steps {
                Ok(Decision {
                    situation: "All steps done.".into(),
                    task: vec![],
                    tool_calls: vec![],
                    completed: true,
                })
            } else {
                Ok(Decision {
                    situation: format!("Step {}", tool_msgs + 1),
                    task: vec![],
                    tool_calls: vec![ToolCall {
                        id: format!("call_{}", tool_msgs),
                        name: "echo".into(),
                        arguments: serde_json::json!({}),
                    }],
                    completed: false,
                })
            }
        }
    }

    /// Tool named `echo` that always returns the text "echoed".
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "echo"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("echoed"))
        }
    }

    /// Spawns a `SimpleAgent` explorer on `task` with an empty tool registry.
    fn spawn_simple(swarm: &mut SwarmManager, ctx: &AgentContext, task: &str) -> AgentId {
        swarm
            .spawn(
                SpawnConfig::explorer(task),
                Box::new(SimpleAgent {}),
                ToolRegistry::new(),
                ctx,
            )
            .unwrap()
    }

    #[tokio::test]
    async fn spawn_and_wait() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let id = spawn_simple(&mut swarm, &ctx, "Find all Rust files");

        let result = swarm.wait(&id).await.unwrap();
        assert_eq!(result.status, AgentStatus::Completed);
        assert!(result.summary.contains("Task done"));
    }

    #[tokio::test]
    async fn spawn_with_tools() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();
        let tools = ToolRegistry::new().register(EchoTool);

        let id = swarm
            .spawn(
                SpawnConfig::worker("Do 2 steps"),
                Box::new(StepAgent { steps: 2 }),
                tools,
                &ctx,
            )
            .unwrap();

        let result = swarm.wait(&id).await.unwrap();
        assert_eq!(result.status, AgentStatus::Completed);
        assert!(result.steps >= 2);
    }

    #[tokio::test]
    async fn cancel_agent() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let id = swarm
            .spawn(
                SpawnConfig {
                    role: AgentRole::Worker,
                    system_prompt: None,
                    tool_names: None,
                    cwd: None,
                    task: "Long task".into(),
                    max_steps: 100,
                    writable_roots: None,
                },
                Box::new(StepAgent { steps: 100 }),
                ToolRegistry::new().register(EchoTool),
                &ctx,
            )
            .unwrap();

        swarm.cancel(&id).unwrap();

        // The outcome is inherently racy. `tokio::select!` (without `biased;`) polls
        // its branches in random order. If the loop never yields, it may finish, or
        // hit its step budget, within the same poll as the already-cancelled token.
        // So this only checks that waiting yields a terminal status.
        let result = swarm.wait(&id).await.unwrap();
        assert!(
            result.status == AgentStatus::Cancelled
                || matches!(result.status, AgentStatus::Failed(_))
                || result.status == AgentStatus::Completed
        );
    }

    #[tokio::test]
    async fn max_agents_limit() {
        let mut swarm = SwarmManager::new().with_limits(2, 3);
        let ctx = AgentContext::new();

        let _id1 = spawn_simple(&mut swarm, &ctx, "Task 1");
        let _id2 = spawn_simple(&mut swarm, &ctx, "Task 2");

        let err = swarm
            .spawn(
                SpawnConfig::explorer("Task 3"),
                Box::new(SimpleAgent {}),
                ToolRegistry::new(),
                &ctx,
            )
            .unwrap_err();
        assert!(matches!(err, SwarmError::MaxAgents(2)));
    }

    #[tokio::test]
    async fn max_depth_limit() {
        let mut swarm = SwarmManager::new().with_limits(8, 3).with_depth(3);
        let ctx = AgentContext::new();

        let err = swarm
            .spawn(
                SpawnConfig::explorer("Task"),
                Box::new(SimpleAgent {}),
                ToolRegistry::new(),
                &ctx,
            )
            .unwrap_err();
        assert!(matches!(err, SwarmError::MaxDepth(3)));
    }

    #[tokio::test]
    async fn status_tracking() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let id = spawn_simple(&mut swarm, &ctx, "Quick task");

        let result = swarm.wait(&id).await.unwrap();
        assert_eq!(result.status, AgentStatus::Completed);

        // `wait` stops tracking the agent once its result is collected.
        assert!(swarm.status(&id).await.is_none());
    }

    #[tokio::test]
    async fn wait_all_returns_results() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let _id1 = spawn_simple(&mut swarm, &ctx, "Task 1");
        let _id2 = swarm
            .spawn(
                SpawnConfig::worker("Task 2"),
                Box::new(SimpleAgent {}),
                ToolRegistry::new(),
                &ctx,
            )
            .unwrap();

        let results = swarm.wait_all().await;
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.status == AgentStatus::Completed));
        assert_eq!(swarm.agent_count(), 0);
    }

    #[tokio::test]
    async fn unknown_agent_is_not_found() {
        let mut swarm = SwarmManager::new();
        let ghost = AgentId("agent-does-not-exist".into());

        assert!(matches!(swarm.cancel(&ghost), Err(SwarmError::NotFound(_))));
        assert!(matches!(swarm.take_receiver(&ghost), Err(SwarmError::NotFound(_))));
        assert!(matches!(swarm.wait(&ghost).await, Err(SwarmError::NotFound(_))));
        assert!(swarm.status(&ghost).await.is_none());
    }

    #[tokio::test]
    async fn receiver_can_only_be_taken_once() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();
        let id = spawn_simple(&mut swarm, &ctx, "Quick task");

        let rx = swarm.take_receiver(&id).unwrap();
        assert!(matches!(
            swarm.take_receiver(&id),
            Err(SwarmError::AlreadyCompleted(_))
        ));
        // Taking the receiver removes the agent from the active count but keeps it tracked.
        assert_eq!(swarm.active_count(), 0);
        assert_eq!(swarm.agent_count(), 1);

        let result = rx.await.unwrap();
        assert_eq!(result.id, id);
        assert_eq!(result.status, AgentStatus::Completed);
    }

    #[tokio::test]
    async fn counts_track_spawn_and_wait() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();
        assert_eq!(swarm.agent_count(), 0);

        let id = spawn_simple(&mut swarm, &ctx, "Count me");
        assert_eq!(swarm.agent_count(), 1);
        assert_eq!(swarm.active_count(), 1);
        assert_eq!(swarm.all_agent_ids(), vec![id.clone()]);

        swarm.wait(&id).await.unwrap();
        assert_eq!(swarm.agent_count(), 0);
        assert_eq!(swarm.active_count(), 0);
    }

    #[tokio::test]
    async fn completion_sends_notification() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();
        let id = spawn_simple(&mut swarm, &ctx, "Notify me");

        swarm.wait(&id).await.unwrap();
        let note = swarm
            .recv_notification(Duration::from_secs(5))
            .await
            .expect("a completion notification");
        assert_eq!(note.id, id);
        assert_eq!(note.role, AgentRole::Explorer);
        assert_eq!(note.status, AgentStatus::Completed);
    }

    #[tokio::test]
    async fn wait_with_timeout_formats_results() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();
        let id = spawn_simple(&mut swarm, &ctx, "Summarize");
        let ghost = AgentId("agent-missing".into());

        let results = swarm
            .wait_with_timeout(&[id.clone(), ghost.clone()], Duration::from_secs(5))
            .await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id);
        assert!(
            results[0].1.starts_with("completed (explorer, "),
            "{}",
            results[0].1
        );
        assert!(results[0].1.contains("Task done"));
        assert_eq!(results[1].0, ghost);
        assert!(results[1].1.starts_with("Error: "));
        assert!(swarm.status(&id).await.is_none());
    }

    #[tokio::test]
    async fn empty_swarm_reports() {
        let swarm = SwarmManager::new();
        assert_eq!(swarm.status_all_formatted().await, "No agents.");
        assert_eq!(swarm.status_summary().await.trim(), "(none)");
        assert!(swarm.status_all().await.is_empty());
        assert!(swarm.try_recv_notification().await.is_none());
    }

    #[test]
    fn agent_role_display() {
        assert_eq!(AgentRole::Explorer.name(), "explorer");
        assert_eq!(AgentRole::Worker.name(), "worker");
        assert_eq!(AgentRole::Reviewer.name(), "reviewer");
        assert_eq!(AgentRole::Custom("planner".into()).name(), "planner");
    }

    #[test]
    fn spawn_config_constructors() {
        let cfg = SpawnConfig::explorer("Find files");
        assert_eq!(cfg.role, AgentRole::Explorer);
        assert_eq!(cfg.max_steps, 10);

        let cfg = SpawnConfig::worker("Implement feature");
        assert_eq!(cfg.role, AgentRole::Worker);
        assert_eq!(cfg.max_steps, 30);

        let cfg = SpawnConfig::reviewer("Review code");
        assert_eq!(cfg.role, AgentRole::Reviewer);
        assert_eq!(cfg.max_steps, 15);
    }
}