// pawan/agent/pool.rs

use std::sync::{
    atomic::{AtomicBool, AtomicUsize, Ordering},
    Arc,
};
use std::time::Instant;

use tokio::sync::{mpsc, Semaphore};
use tokio_util::sync::CancellationToken;

/// Status of a pooled task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task ran and its executor reported success.
    Completed,
    /// The task failed; the pool also uses this when it cannot obtain a
    /// concurrency permit or when the spawned task cannot be joined.
    Failed,
    /// The task was cancelled, either before it started or while it was running.
    Cancelled,
    /// The task exceeded a deadline. The pool never sets this itself;
    /// presumably it is reported by executors that enforce timeouts.
    TimedOut,
}

/// Pool event stream for progress reporting.
#[derive(Debug, Clone)]
pub enum AgentPoolEvent {
    /// Sent each time an executor call returns.
    Progress {
        /// Number of tasks whose executor call has returned so far in this run.
        completed: usize,
        /// Number of tasks submitted in this run.
        total: usize,
    },
}

/// Input task for the pool.
///
/// The pool itself only reads `id`; the remaining fields are handed to the
/// executor unchanged.
#[derive(Debug, Clone)]
pub struct PoolTask {
    /// Caller-chosen identifier, echoed back in the matching [`PoolResult`].
    pub id: String,
    /// Kind of agent that should handle the task (interpreted by the executor).
    pub agent_type: String,
    /// The work to perform (interpreted by the executor).
    pub assignment: String,
    /// Optional extra context for the executor.
    pub context: Option<String>,
}

34/// Result record for a pooled task.
35#[derive(Debug, Clone)]
36pub struct PoolResult {
37    pub id: String,
38    pub status: TaskStatus, // Completed, Failed, Cancelled, TimedOut
39    pub output: Option<String>,
40    pub error: Option<String>,
41    pub duration_ms: u64,
42}
43
44#[async_trait::async_trait]
45pub trait PoolExecutor: Send + Sync + 'static {
46    async fn execute_task(&self, task: PoolTask, cancel: CancellationToken) -> PoolResult;
47}
48
49/// Run multiple tasks concurrently with a concurrency limit.
50pub struct AgentPool {
51    pub max_concurrent: usize, // default: number of CPU cores
52    pub tasks: Vec<PoolTask>,
53    pub results: Vec<PoolResult>,
54    stop_on_error: bool,
55    cancel: CancellationToken,
56    progress_tx: Option<mpsc::UnboundedSender<AgentPoolEvent>>,
57    executor: Arc<dyn PoolExecutor>,
58}
59
60impl AgentPool {
61    pub fn new(executor: Arc<dyn PoolExecutor>) -> Self {
62        let default_concurrency = std::thread::available_parallelism()
63            .map(|n| n.get())
64            .unwrap_or(1);
65
66        Self {
67            max_concurrent: default_concurrency,
68            tasks: Vec::new(),
69            results: Vec::new(),
70            stop_on_error: false,
71            cancel: CancellationToken::new(),
72            progress_tx: None,
73            executor,
74        }
75    }
76
77    pub fn with_max_concurrent(mut self, max_concurrent: usize) -> Self {
78        self.max_concurrent = max_concurrent.max(1);
79        self
80    }
81
82    pub fn with_stop_on_error(mut self, stop_on_error: bool) -> Self {
83        self.stop_on_error = stop_on_error;
84        self
85    }
86
87    pub fn with_progress_sender(mut self, tx: mpsc::UnboundedSender<AgentPoolEvent>) -> Self {
88        self.progress_tx = Some(tx);
89        self
90    }
91
92    pub fn cancel_token(&self) -> CancellationToken {
93        self.cancel.clone()
94    }
95
96    /// Execute all tasks, returning results in the original order.
97    pub async fn execute(&mut self) -> Vec<PoolResult> {
98        self.results.clear();
99
100        if self.tasks.is_empty() {
101            return Vec::new();
102        }
103
104        let total = self.tasks.len();
105        let semaphore = Arc::new(Semaphore::new(self.max_concurrent.max(1)));
106        let completed = Arc::new(AtomicUsize::new(0));
107        let has_failed = Arc::new(AtomicBool::new(false));
108
109        let mut handles = Vec::with_capacity(total);
110
111        for (idx, task) in self.tasks.clone().into_iter().enumerate() {
112            let sem = semaphore.clone();
113            let executor = self.executor.clone();
114            let cancel = self.cancel.clone();
115            let completed_ctr = completed.clone();
116            let has_failed_flag = has_failed.clone();
117            let stop_on_error = self.stop_on_error;
118            let progress_tx = self.progress_tx.clone();
119
120            let handle = tokio::task::spawn(async move {
121                if cancel.is_cancelled() {
122                    return (
123                        idx,
124                        PoolResult {
125                            id: task.id.clone(),
126                            status: TaskStatus::Cancelled,
127                            output: None,
128                            error: Some("cancelled".to_string()),
129                            duration_ms: 0,
130                        },
131                    );
132                }
133
134                // Acquire concurrency permit (cancellation-aware).
135                let permit = tokio::select! {
136                    _ = cancel.cancelled() => {
137                        return (idx, PoolResult {
138                            id: task.id.clone(),
139                            status: TaskStatus::Cancelled,
140                            output: None,
141                            error: Some("cancelled".to_string()),
142                            duration_ms: 0,
143                        });
144                    }
145                    p = sem.acquire() => p,
146                };
147
148                let _permit = match permit {
149                    Ok(p) => p,
150                    Err(_) => {
151                        return (
152                            idx,
153                            PoolResult {
154                                id: task.id.clone(),
155                                status: TaskStatus::Failed,
156                                output: None,
157                                error: Some("semaphore closed".to_string()),
158                                duration_ms: 0,
159                            },
160                        );
161                    }
162                };
163
164                if cancel.is_cancelled() {
165                    return (
166                        idx,
167                        PoolResult {
168                            id: task.id.clone(),
169                            status: TaskStatus::Cancelled,
170                            output: None,
171                            error: Some("cancelled".to_string()),
172                            duration_ms: 0,
173                        },
174                    );
175                }
176
177                let started = Instant::now();
178                let mut result = executor.execute_task(task, cancel.clone()).await;
179                result.duration_ms = started.elapsed().as_millis() as u64;
180
181                if stop_on_error && result.status == TaskStatus::Failed {
182                    has_failed_flag.store(true, Ordering::SeqCst);
183                    cancel.cancel();
184                }
185
186                let done = completed_ctr.fetch_add(1, Ordering::SeqCst) + 1;
187                if let Some(tx) = progress_tx {
188                    let _ = tx.send(AgentPoolEvent::Progress {
189                        completed: done,
190                        total,
191                    });
192                }
193
194                if stop_on_error
195                    && has_failed_flag.load(Ordering::SeqCst)
196                    && result.status != TaskStatus::Completed
197                    && result.status != TaskStatus::Failed
198                {
199                    result.status = TaskStatus::Cancelled;
200                    result.output = None;
201                    if result.error.is_none() {
202                        result.error = Some("cancelled".to_string());
203                    }
204                }
205
206                (idx, result)
207            });
208
209            handles.push(handle);
210        }
211
212        let mut out: Vec<Option<PoolResult>> = vec![None; total];
213        for h in handles {
214            match h.await {
215                Ok((idx, r)) => out[idx] = Some(r),
216                Err(join_err) => {
217                    let r = PoolResult {
218                        id: "<join>".to_string(),
219                        status: TaskStatus::Failed,
220                        output: None,
221                        error: Some(format!("join error: {join_err}")),
222                        duration_ms: 0,
223                    };
224                    if let Some(slot) = out.iter_mut().find(|s| s.is_none()) {
225                        *slot = Some(r);
226                    }
227                }
228            }
229        }
230
231        let results: Vec<PoolResult> = out
232            .into_iter()
233            .map(|r| {
234                r.unwrap_or(PoolResult {
235                    id: "<missing>".to_string(),
236                    status: TaskStatus::Cancelled,
237                    output: None,
238                    error: Some("missing result".to_string()),
239                    duration_ms: 0,
240                })
241            })
242            .collect();
243
244        self.results = results.clone();
245        results
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Executor that sleeps briefly, then fails tasks whose assignment
    /// contains "fail" and completes all others; it returns `Cancelled` if the
    /// token fires first.
    struct TestExecutor;

    #[async_trait::async_trait]
    impl PoolExecutor for TestExecutor {
        async fn execute_task(&self, task: PoolTask, cancel: CancellationToken) -> PoolResult {
            tokio::select! {
                _ = cancel.cancelled() => {
                    PoolResult {
                        id: task.id.clone(),
                        status: TaskStatus::Cancelled,
                        output: None,
                        error: Some("cancelled".to_string()),
                        duration_ms: 0,
                    }
                }
                _ = sleep(Duration::from_millis(25)) => {
                    if task.assignment.contains("fail") {
                        PoolResult {
                            id: task.id.clone(),
                            status: TaskStatus::Failed,
                            output: None,
                            error: Some("boom".to_string()),
                            duration_ms: 0,
                        }
                    } else {
                        PoolResult {
                            id: task.id.clone(),
                            status: TaskStatus::Completed,
                            output: Some(format!("ok:{}", task.id)),
                            error: None,
                            duration_ms: 0,
                        }
                    }
                }
            }
        }
    }

    /// Build a task with the given id and assignment and fixed filler fields.
    fn task(id: &str, assignment: &str) -> PoolTask {
        PoolTask {
            id: id.into(),
            agent_type: "t".into(),
            assignment: assignment.into(),
            context: None,
        }
    }

    #[tokio::test]
    async fn pool_empty_returns_empty_vec() {
        let exec = Arc::new(TestExecutor);
        let mut pool = AgentPool::new(exec);
        let results = pool.execute().await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn pool_three_tasks_all_complete() {
        let exec = Arc::new(TestExecutor);
        let mut pool = AgentPool::new(exec).with_max_concurrent(2);
        pool.tasks = vec![task("a", "ok"), task("b", "ok"), task("c", "ok")];

        let results = pool.execute().await;
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "a");
        assert_eq!(results[1].id, "b");
        assert_eq!(results[2].id, "c");
        assert!(results.iter().all(|r| r.status == TaskStatus::Completed));
    }

    #[tokio::test]
    async fn pool_stop_on_error_cancels_remaining() {
        let exec = Arc::new(TestExecutor);
        let mut pool = AgentPool::new(exec)
            .with_max_concurrent(3)
            .with_stop_on_error(true);

        pool.tasks = vec![task("ok1", "ok"), task("bad", "fail"), task("ok2", "ok")];

        let results = pool.execute().await;
        assert_eq!(results.len(), 3);
        assert!(results
            .iter()
            .any(|r| r.id == "bad" && r.status == TaskStatus::Failed));
        // The siblings race the failure, so either outcome is acceptable.
        assert!(results
            .iter()
            .filter(|r| r.id != "bad")
            .all(|r| r.status == TaskStatus::Completed || r.status == TaskStatus::Cancelled));
    }
}