1use std::sync::{
2 atomic::{AtomicBool, AtomicUsize, Ordering},
3 Arc,
4};
5use std::time::Instant;
6
7use tokio::sync::{mpsc, Semaphore};
8use tokio_util::sync::CancellationToken;
9
/// Terminal state recorded for a single pool task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// The executor reported that the task finished successfully.
    Completed,

    /// The task failed. The pool also uses this status when it cannot obtain
    /// a semaphore permit or when the spawned task could not be joined.
    Failed,

    /// The task observed cancellation, either before it started or while it
    /// was running.
    Cancelled,

    /// The task ran out of time.
    ///
    /// NOTE(review): nothing in this file produces this status; presumably it
    /// is set by executors that enforce their own deadlines. Confirm.
    TimedOut,
}
18
/// Notification sent over the optional progress channel of an agent pool.
#[derive(Debug, Clone)]
pub enum AgentPoolEvent {
    /// One more task has returned from the executor.
    Progress {
        /// Number of tasks that have returned from the executor so far in
        /// the current run, including the one that triggered this event.
        completed: usize,
        /// Total number of tasks in the current run.
        total: usize,
    },
}
24
/// A unit of work handed to the pool's executor.
///
/// The pool itself only reads `id` (to label results it synthesizes); every
/// other field is passed through to the executor untouched.
#[derive(Debug, Clone)]
pub struct PoolTask {
    /// Identifier copied into results that the pool builds on the task's behalf.
    pub id: String,

    /// Kind of agent that should handle the task.
    /// NOTE(review): interpreted by the executor only; valid values are not
    /// visible from this file.
    pub agent_type: String,

    /// The work description given to the agent.
    pub assignment: String,

    /// Optional extra context accompanying the assignment.
    pub context: Option<String>,
}
33
34#[derive(Debug, Clone)]
36pub struct PoolResult {
37 pub id: String,
38 pub status: TaskStatus, pub output: Option<String>,
40 pub error: Option<String>,
41 pub duration_ms: u64,
42}
43
44#[async_trait::async_trait]
45pub trait PoolExecutor: Send + Sync + 'static {
46 async fn execute_task(&self, task: PoolTask, cancel: CancellationToken) -> PoolResult;
47}
48
49pub struct AgentPool {
51 pub max_concurrent: usize, pub tasks: Vec<PoolTask>,
53 pub results: Vec<PoolResult>,
54 stop_on_error: bool,
55 cancel: CancellationToken,
56 progress_tx: Option<mpsc::UnboundedSender<AgentPoolEvent>>,
57 executor: Arc<dyn PoolExecutor>,
58}
59
60impl AgentPool {
61 pub fn new(executor: Arc<dyn PoolExecutor>) -> Self {
62 let default_concurrency = std::thread::available_parallelism()
63 .map(|n| n.get())
64 .unwrap_or(1);
65
66 Self {
67 max_concurrent: default_concurrency,
68 tasks: Vec::new(),
69 results: Vec::new(),
70 stop_on_error: false,
71 cancel: CancellationToken::new(),
72 progress_tx: None,
73 executor,
74 }
75 }
76
77 pub fn with_max_concurrent(mut self, max_concurrent: usize) -> Self {
78 self.max_concurrent = max_concurrent.max(1);
79 self
80 }
81
82 pub fn with_stop_on_error(mut self, stop_on_error: bool) -> Self {
83 self.stop_on_error = stop_on_error;
84 self
85 }
86
87 pub fn with_progress_sender(mut self, tx: mpsc::UnboundedSender<AgentPoolEvent>) -> Self {
88 self.progress_tx = Some(tx);
89 self
90 }
91
92 pub fn cancel_token(&self) -> CancellationToken {
93 self.cancel.clone()
94 }
95
96 pub async fn execute(&mut self) -> Vec<PoolResult> {
98 self.results.clear();
99
100 if self.tasks.is_empty() {
101 return Vec::new();
102 }
103
104 let total = self.tasks.len();
105 let semaphore = Arc::new(Semaphore::new(self.max_concurrent.max(1)));
106 let completed = Arc::new(AtomicUsize::new(0));
107 let has_failed = Arc::new(AtomicBool::new(false));
108
109 let mut handles = Vec::with_capacity(total);
110
111 for (idx, task) in self.tasks.clone().into_iter().enumerate() {
112 let sem = semaphore.clone();
113 let executor = self.executor.clone();
114 let cancel = self.cancel.clone();
115 let completed_ctr = completed.clone();
116 let has_failed_flag = has_failed.clone();
117 let stop_on_error = self.stop_on_error;
118 let progress_tx = self.progress_tx.clone();
119
120 let handle = tokio::task::spawn(async move {
121 if cancel.is_cancelled() {
122 return (
123 idx,
124 PoolResult {
125 id: task.id.clone(),
126 status: TaskStatus::Cancelled,
127 output: None,
128 error: Some("cancelled".to_string()),
129 duration_ms: 0,
130 },
131 );
132 }
133
134 let permit = tokio::select! {
136 _ = cancel.cancelled() => {
137 return (idx, PoolResult {
138 id: task.id.clone(),
139 status: TaskStatus::Cancelled,
140 output: None,
141 error: Some("cancelled".to_string()),
142 duration_ms: 0,
143 });
144 }
145 p = sem.acquire() => p,
146 };
147
148 let _permit = match permit {
149 Ok(p) => p,
150 Err(_) => {
151 return (
152 idx,
153 PoolResult {
154 id: task.id.clone(),
155 status: TaskStatus::Failed,
156 output: None,
157 error: Some("semaphore closed".to_string()),
158 duration_ms: 0,
159 },
160 );
161 }
162 };
163
164 if cancel.is_cancelled() {
165 return (
166 idx,
167 PoolResult {
168 id: task.id.clone(),
169 status: TaskStatus::Cancelled,
170 output: None,
171 error: Some("cancelled".to_string()),
172 duration_ms: 0,
173 },
174 );
175 }
176
177 let started = Instant::now();
178 let mut result = executor.execute_task(task, cancel.clone()).await;
179 result.duration_ms = started.elapsed().as_millis() as u64;
180
181 if stop_on_error && result.status == TaskStatus::Failed {
182 has_failed_flag.store(true, Ordering::SeqCst);
183 cancel.cancel();
184 }
185
186 let done = completed_ctr.fetch_add(1, Ordering::SeqCst) + 1;
187 if let Some(tx) = progress_tx {
188 let _ = tx.send(AgentPoolEvent::Progress {
189 completed: done,
190 total,
191 });
192 }
193
194 if stop_on_error
195 && has_failed_flag.load(Ordering::SeqCst)
196 && result.status != TaskStatus::Completed
197 && result.status != TaskStatus::Failed
198 {
199 result.status = TaskStatus::Cancelled;
200 result.output = None;
201 if result.error.is_none() {
202 result.error = Some("cancelled".to_string());
203 }
204 }
205
206 (idx, result)
207 });
208
209 handles.push(handle);
210 }
211
212 let mut out: Vec<Option<PoolResult>> = vec![None; total];
213 for h in handles {
214 match h.await {
215 Ok((idx, r)) => out[idx] = Some(r),
216 Err(join_err) => {
217 let r = PoolResult {
218 id: "<join>".to_string(),
219 status: TaskStatus::Failed,
220 output: None,
221 error: Some(format!("join error: {join_err}")),
222 duration_ms: 0,
223 };
224 if let Some(slot) = out.iter_mut().find(|s| s.is_none()) {
225 *slot = Some(r);
226 }
227 }
228 }
229 }
230
231 let results: Vec<PoolResult> = out
232 .into_iter()
233 .map(|r| {
234 r.unwrap_or(PoolResult {
235 id: "<missing>".to_string(),
236 status: TaskStatus::Cancelled,
237 output: None,
238 error: Some("missing result".to_string()),
239 duration_ms: 0,
240 })
241 })
242 .collect();
243
244 self.results = results.clone();
245 results
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Executor stub: waits 25 ms, then fails tasks whose assignment contains
    /// "fail" and completes all others. Returns `Cancelled` if the token fires
    /// during the wait.
    struct TestExecutor;

    #[async_trait::async_trait]
    impl PoolExecutor for TestExecutor {
        async fn execute_task(&self, task: PoolTask, cancel: CancellationToken) -> PoolResult {
            let (status, output, error): (TaskStatus, Option<String>, Option<String>) = tokio::select! {
                _ = cancel.cancelled() => {
                    (TaskStatus::Cancelled, None, Some("cancelled".to_string()))
                }
                _ = sleep(Duration::from_millis(25)) => {
                    if task.assignment.contains("fail") {
                        (TaskStatus::Failed, None, Some("boom".to_string()))
                    } else {
                        (TaskStatus::Completed, Some(format!("ok:{}", task.id)), None)
                    }
                }
            };

            PoolResult {
                id: task.id,
                status,
                output,
                error,
                duration_ms: 0,
            }
        }
    }

    /// Builds a task of agent type "t" with no context.
    fn make_task(id: &str, assignment: &str) -> PoolTask {
        PoolTask {
            id: id.into(),
            agent_type: "t".into(),
            assignment: assignment.into(),
            context: None,
        }
    }

    #[tokio::test]
    async fn pool_empty_returns_empty_vec() {
        let mut pool = AgentPool::new(Arc::new(TestExecutor));
        assert!(pool.execute().await.is_empty());
    }

    #[tokio::test]
    async fn pool_three_tasks_all_complete() {
        let mut pool = AgentPool::new(Arc::new(TestExecutor)).with_max_concurrent(2);
        pool.tasks = ["a", "b", "c"]
            .iter()
            .map(|id| make_task(id, "ok"))
            .collect();

        let results = pool.execute().await;

        assert_eq!(results.len(), 3);
        // Results must come back in submission order.
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "b", "c"]);
        for r in &results {
            assert_eq!(r.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn pool_stop_on_error_cancels_remaining() {
        let mut pool = AgentPool::new(Arc::new(TestExecutor))
            .with_max_concurrent(3)
            .with_stop_on_error(true);
        pool.tasks = vec![
            make_task("ok1", "ok"),
            make_task("bad", "fail"),
            make_task("ok2", "ok"),
        ];

        let results = pool.execute().await;

        assert_eq!(results.len(), 3);
        // The failing task keeps its own Failed status.
        assert!(results
            .iter()
            .any(|r| r.id == "bad" && r.status == TaskStatus::Failed));
        // Siblings either finished before the failure or were cancelled by it.
        assert!(results
            .iter()
            .filter(|r| r.id != "bad")
            .all(|r| matches!(r.status, TaskStatus::Completed | TaskStatus::Cancelled)));
    }
}