1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use super::TextAgent;
8use crate::error::AgentError;
9use crate::state::State;
10
11#[derive(Clone, Default)]
13pub struct TaskRegistry {
14 pub(crate) inner:
15 Arc<tokio::sync::Mutex<HashMap<String, tokio::task::JoinHandle<Result<String, String>>>>>,
16}
17
18impl TaskRegistry {
19 pub fn new() -> Self {
21 Self::default()
22 }
23}
24
25pub struct DispatchTextAgent {
30 name: String,
31 children: Vec<(String, Arc<dyn TextAgent>)>,
32 registry: TaskRegistry,
33 budget: Arc<tokio::sync::Semaphore>,
34}
35
36impl DispatchTextAgent {
37 pub fn new(
39 name: impl Into<String>,
40 children: Vec<(String, Arc<dyn TextAgent>)>,
41 registry: TaskRegistry,
42 budget: Arc<tokio::sync::Semaphore>,
43 ) -> Self {
44 Self {
45 name: name.into(),
46 children,
47 registry,
48 budget,
49 }
50 }
51}
52
53#[async_trait]
54impl TextAgent for DispatchTextAgent {
55 fn name(&self) -> &str {
56 &self.name
57 }
58
59 async fn run(&self, state: &State) -> Result<String, AgentError> {
60 let mut registry = self.registry.inner.lock().await;
61
62 for (task_name, agent) in &self.children {
63 let agent = agent.clone();
64 let state = state.clone();
65 let budget = self.budget.clone();
66 let task_name_owned = task_name.clone();
67
68 let handle = tokio::spawn(async move {
69 let _permit = budget
70 .acquire()
71 .await
72 .map_err(|e| format!("Semaphore closed: {e}"))?;
73 agent
74 .run(&state)
75 .await
76 .map_err(|e| format!("Task '{}' failed: {}", task_name_owned, e))
77 });
78
79 registry.insert(task_name.clone(), handle);
80 }
81
82 state.set(
83 "_dispatch_status",
84 self.children
85 .iter()
86 .map(|(name, _)| (name.clone(), "running".to_string()))
87 .collect::<HashMap<String, String>>(),
88 );
89
90 Ok(String::new())
91 }
92}
93
94pub struct JoinTextAgent {
98 name: String,
99 registry: TaskRegistry,
100 target_names: Option<Vec<String>>,
101 timeout: Option<Duration>,
102}
103
104impl JoinTextAgent {
105 pub fn new(name: impl Into<String>, registry: TaskRegistry) -> Self {
107 Self {
108 name: name.into(),
109 registry,
110 target_names: None,
111 timeout: None,
112 }
113 }
114
115 pub fn targets(mut self, names: Vec<String>) -> Self {
117 self.target_names = Some(names);
118 self
119 }
120
121 pub fn timeout(mut self, timeout: Duration) -> Self {
123 self.timeout = Some(timeout);
124 self
125 }
126}
127
128#[async_trait]
129impl TextAgent for JoinTextAgent {
130 fn name(&self) -> &str {
131 &self.name
132 }
133
134 async fn run(&self, state: &State) -> Result<String, AgentError> {
135 let mut registry = self.registry.inner.lock().await;
136
137 let tasks: HashMap<String, _> = if let Some(targets) = &self.target_names {
139 targets
140 .iter()
141 .filter_map(|name| registry.remove(name).map(|h| (name.clone(), h)))
142 .collect()
143 } else {
144 std::mem::take(&mut *registry)
145 };
146 drop(registry);
147
148 let mut results = Vec::new();
149
150 for (task_name, handle) in tasks {
151 let result = if let Some(timeout) = self.timeout {
152 match tokio::time::timeout(timeout, handle).await {
153 Ok(Ok(Ok(text))) => {
154 state.set(format!("_result_{}", task_name), &text);
155 Ok(text)
156 }
157 Ok(Ok(Err(e))) => Err(AgentError::Other(e)),
158 Ok(Err(e)) => Err(AgentError::Other(format!("Join error: {e}"))),
159 Err(_) => Err(AgentError::Timeout),
160 }
161 } else {
162 match handle.await {
163 Ok(Ok(text)) => {
164 state.set(format!("_result_{}", task_name), &text);
165 Ok(text)
166 }
167 Ok(Err(e)) => Err(AgentError::Other(e)),
168 Err(e) => Err(AgentError::Other(format!("Join error: {e}"))),
169 }
170 };
171
172 results.push(result?);
173 }
174
175 let combined = results.join("\n");
176 state.set("output", &combined);
177 Ok(combined)
178 }
179}