// brainwires_agents/pool.rs
//! Agent Pool - Manages a pool of background task agents
//!
//! [`AgentPool`] handles the lifecycle of [`TaskAgent`]s: spawning, monitoring,
//! stopping, and awaiting results. All agents in the pool share the same
//! [`Provider`], tool executor, communication hub, and file lock manager.
//!
//! ## Usage
//!
//! ```rust,ignore
//! use std::sync::Arc;
//! use brainwires_agents::{AgentPool, TaskAgentConfig};
//! use brainwires_core::Task;
//!
//! let pool = AgentPool::new(
//!     10,
//!     Arc::clone(&provider),
//!     Arc::clone(&tool_executor),
//!     Arc::clone(&hub),
//!     Arc::clone(&lock_manager),
//!     "/my/project".to_string(),
//! );
//!
//! let agent_id = pool.spawn_agent(
//!     Task::new("t-1", "Implement feature X"),
//!     None,
//! ).await?;
//!
//! let result = pool.await_completion(&agent_id).await?;
//! ```
use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{Result, anyhow};
use tokio::sync::RwLock;
use tokio::task::JoinHandle;

use brainwires_core::{Provider, Task};
use brainwires_tool_system::ToolExecutor;

use crate::communication::CommunicationHub;
use crate::context::AgentContext;
use crate::file_locks::FileLockManager;
use crate::task_agent::{
    TaskAgent, TaskAgentConfig, TaskAgentResult, TaskAgentStatus, spawn_task_agent,
};

48// ── Internal handle ────────────────────────────────────────────────────────
49
50struct AgentHandle {
51    agent: Arc<TaskAgent>,
52    join_handle: JoinHandle<Result<TaskAgentResult>>,
53}
54
55// ── Public API ─────────────────────────────────────────────────────────────
56
57/// Manages a pool of background [`TaskAgent`]s.
58///
59/// All agents share the same provider, tool executor, communication hub,
60/// file lock manager, and working directory. Each agent gets its own
61/// conversation history and working set.
62pub struct AgentPool {
63    max_agents: usize,
64    agents: Arc<RwLock<HashMap<String, AgentHandle>>>,
65    communication_hub: Arc<CommunicationHub>,
66    file_lock_manager: Arc<FileLockManager>,
67    provider: Arc<dyn Provider>,
68    tool_executor: Arc<dyn ToolExecutor>,
69    working_directory: String,
70}
71
72impl AgentPool {
73    /// Create a new agent pool.
74    ///
75    /// # Parameters
76    /// - `max_agents`: maximum number of concurrently running agents.
77    /// - `provider`: AI provider shared by all agents.
78    /// - `tool_executor`: tool executor shared by all agents.
79    /// - `communication_hub`: inter-agent message bus.
80    /// - `file_lock_manager`: file coordination across agents.
81    /// - `working_directory`: default working directory for spawned agents.
82    pub fn new(
83        max_agents: usize,
84        provider: Arc<dyn Provider>,
85        tool_executor: Arc<dyn ToolExecutor>,
86        communication_hub: Arc<CommunicationHub>,
87        file_lock_manager: Arc<FileLockManager>,
88        working_directory: impl Into<String>,
89    ) -> Self {
90        Self {
91            max_agents,
92            agents: Arc::new(RwLock::new(HashMap::new())),
93            communication_hub,
94            file_lock_manager,
95            provider,
96            tool_executor,
97            working_directory: working_directory.into(),
98        }
99    }
100
101    /// Spawn a new task agent and start it on a Tokio background task.
102    ///
103    /// Returns the agent ID. Use [`await_completion`][Self::await_completion]
104    /// to wait for the result.
105    ///
106    /// Returns an error if the pool is already at capacity.
107    pub async fn spawn_agent(&self, task: Task, config: Option<TaskAgentConfig>) -> Result<String> {
108        {
109            let agents = self.agents.read().await;
110            if agents.len() >= self.max_agents {
111                return Err(anyhow!(
112                    "Agent pool is full ({}/{})",
113                    agents.len(),
114                    self.max_agents
115                ));
116            }
117        }
118
119        let agent_id = format!("agent-{}", uuid::Uuid::new_v4());
120        let config = config.unwrap_or_default();
121
122        let context = Arc::new(AgentContext::new(
123            self.working_directory.clone(),
124            Arc::clone(&self.tool_executor),
125            Arc::clone(&self.communication_hub),
126            Arc::clone(&self.file_lock_manager),
127        ));
128
129        let agent = Arc::new(TaskAgent::new(
130            agent_id.clone(),
131            task,
132            Arc::clone(&self.provider),
133            context,
134            config,
135        ));
136
137        let handle = spawn_task_agent(Arc::clone(&agent));
138
139        self.agents.write().await.insert(
140            agent_id.clone(),
141            AgentHandle {
142                agent,
143                join_handle: handle,
144            },
145        );
146
147        tracing::info!(agent_id = %agent_id, "spawned agent");
148        Ok(agent_id)
149    }
150
151    /// Spawn a new task agent with a custom [`AgentContext`].
152    ///
153    /// Unlike [`spawn_agent`][Self::spawn_agent] which uses the pool's default
154    /// working directory, this method accepts a pre-built context. This is
155    /// useful for workers that run in isolated worktrees with per-agent
156    /// working directories.
157    ///
158    /// Returns the agent ID.
159    pub async fn spawn_agent_with_context(
160        &self,
161        task: Task,
162        context: Arc<AgentContext>,
163        config: Option<TaskAgentConfig>,
164    ) -> Result<String> {
165        {
166            let agents = self.agents.read().await;
167            if agents.len() >= self.max_agents {
168                return Err(anyhow!(
169                    "Agent pool is full ({}/{})",
170                    agents.len(),
171                    self.max_agents
172                ));
173            }
174        }
175
176        let agent_id = format!("agent-{}", uuid::Uuid::new_v4());
177        let config = config.unwrap_or_default();
178
179        let agent = Arc::new(TaskAgent::new(
180            agent_id.clone(),
181            task,
182            Arc::clone(&self.provider),
183            context,
184            config,
185        ));
186
187        let handle = spawn_task_agent(Arc::clone(&agent));
188
189        self.agents.write().await.insert(
190            agent_id.clone(),
191            AgentHandle {
192                agent,
193                join_handle: handle,
194            },
195        );
196
197        tracing::info!(agent_id = %agent_id, "spawned agent with custom context");
198        Ok(agent_id)
199    }
200
201    /// Get the current status of an agent.
202    ///
203    /// Returns `None` if the agent is not in the pool.
204    pub async fn get_status(&self, agent_id: &str) -> Option<TaskAgentStatus> {
205        let agents = self.agents.read().await;
206        let handle = agents.get(agent_id)?;
207        Some(handle.agent.status().await)
208    }
209
210    /// Get a snapshot of the task assigned to an agent.
211    pub async fn get_task(&self, agent_id: &str) -> Option<Task> {
212        let agents = self.agents.read().await;
213        let handle = agents.get(agent_id)?;
214        Some(handle.agent.task().await)
215    }
216
217    /// Abort an agent and remove it from the pool.
218    ///
219    /// File locks held by the agent are released immediately.
220    pub async fn stop_agent(&self, agent_id: &str) -> Result<()> {
221        let handle = self
222            .agents
223            .write()
224            .await
225            .remove(agent_id)
226            .ok_or_else(|| anyhow!("Agent {} not found", agent_id))?;
227
228        handle.join_handle.abort();
229        self.file_lock_manager.release_all_locks(agent_id).await;
230        tracing::info!(agent_id = %agent_id, "stopped agent");
231        Ok(())
232    }
233
234    /// Wait for an agent to finish and return its result.
235    ///
236    /// The agent is removed from the pool once it completes.
237    pub async fn await_completion(&self, agent_id: &str) -> Result<TaskAgentResult> {
238        let handle = self.agents.write().await.remove(agent_id);
239
240        match handle {
241            Some(h) => match h.join_handle.await {
242                Ok(result) => result,
243                Err(e) => Err(anyhow!("Agent task panicked: {}", e)),
244            },
245            None => Err(anyhow!("Agent {} not found", agent_id)),
246        }
247    }
248
249    /// List all agents currently in the pool with their status.
250    pub async fn list_active(&self) -> Vec<(String, TaskAgentStatus)> {
251        let agents = self.agents.read().await;
252        let mut out = Vec::with_capacity(agents.len());
253        for (id, handle) in agents.iter() {
254            out.push((id.clone(), handle.agent.status().await));
255        }
256        out
257    }
258
259    /// Number of agents currently in the pool (running or pending cleanup).
260    pub async fn active_count(&self) -> usize {
261        self.agents.read().await.len()
262    }
263
264    /// Returns `true` if the agent is still running (join handle not finished).
265    pub async fn is_running(&self, agent_id: &str) -> bool {
266        let agents = self.agents.read().await;
267        agents
268            .get(agent_id)
269            .map(|h| !h.join_handle.is_finished())
270            .unwrap_or(false)
271    }
272
273    /// Remove all finished agents from the pool and return their results.
274    pub async fn cleanup_completed(&self) -> Vec<(String, Result<TaskAgentResult>)> {
275        let finished: Vec<String> = {
276            let agents = self.agents.read().await;
277            agents
278                .iter()
279                .filter(|(_, h)| h.join_handle.is_finished())
280                .map(|(id, _)| id.clone())
281                .collect()
282        };
283
284        let mut results = Vec::new();
285        let mut agents = self.agents.write().await;
286        for id in finished {
287            if let Some(handle) = agents.remove(&id) {
288                let result = match handle.join_handle.await {
289                    Ok(r) => r,
290                    Err(e) => Err(anyhow!("Agent task panicked: {}", e)),
291                };
292                results.push((id, result));
293            }
294        }
295        results
296    }
297
298    /// Wait for every agent in the pool to finish.
299    pub async fn await_all(&self) -> Vec<(String, Result<TaskAgentResult>)> {
300        let ids: Vec<String> = self.agents.read().await.keys().cloned().collect();
301        let mut results = Vec::new();
302        for id in ids {
303            results.push((id.clone(), self.await_completion(&id).await));
304        }
305        results
306    }
307
308    /// Abort all agents and clear the pool.
309    pub async fn shutdown(&self) {
310        let mut agents = self.agents.write().await;
311        for (agent_id, handle) in agents.drain() {
312            handle.join_handle.abort();
313            self.file_lock_manager.release_all_locks(&agent_id).await;
314        }
315        tracing::info!("agent pool shut down");
316    }
317
318    /// Get a statistical snapshot of the pool.
319    pub async fn stats(&self) -> AgentPoolStats {
320        let agents = self.agents.read().await;
321        let mut running = 0usize;
322        let mut completed = 0usize;
323
324        for (_, handle) in agents.iter() {
325            if handle.join_handle.is_finished() {
326                completed += 1;
327            } else {
328                running += 1;
329            }
330        }
331
332        AgentPoolStats {
333            max_agents: self.max_agents,
334            total_agents: agents.len(),
335            running,
336            completed,
337            failed: 0, // Not distinguishable without awaiting the handle.
338        }
339    }
340
341    /// Get the shared file lock manager.
342    pub fn file_lock_manager(&self) -> Arc<FileLockManager> {
343        Arc::clone(&self.file_lock_manager)
344    }
345
346    /// Get the shared communication hub.
347    pub fn communication_hub(&self) -> Arc<CommunicationHub> {
348        Arc::clone(&self.communication_hub)
349    }
350}
351
/// Point-in-time statistics about the agent pool.
#[derive(Debug, Clone)]
pub struct AgentPoolStats {
    /// Maximum concurrent agents allowed.
    pub max_agents: usize,
    /// Total agents currently tracked (running + awaiting cleanup).
    pub total_agents: usize,
    /// Agents whose background task is still running.
    pub running: usize,
    /// Agents whose background task has finished but not yet been cleaned up.
    pub completed: usize,
    /// Agents known to have failed (requires awaiting the handle; see `stats`).
    pub failed: usize,
}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::communication::CommunicationHub;
    use crate::file_locks::FileLockManager;
    use async_trait::async_trait;
    use brainwires_core::{
        ChatOptions, ChatResponse, Message, StreamChunk, Tool, ToolContext, ToolResult, ToolUse,
        Usage,
    };
    use brainwires_tool_system::ToolExecutor;
    use futures::stream::BoxStream;

    /// Provider stub: every `chat` call returns the same canned response.
    struct MockProvider(ChatResponse);

    impl MockProvider {
        /// A provider that immediately finishes with the given assistant text.
        fn done(text: &str) -> Self {
            let canned = ChatResponse {
                message: Message::assistant(text),
                finish_reason: Some("stop".to_string()),
                usage: Usage::default(),
            };
            Self(canned)
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(
            &self,
            _: &[Message],
            _: Option<&[Tool]>,
            _: &ChatOptions,
        ) -> Result<ChatResponse> {
            Ok(self.0.clone())
        }

        fn stream_chat<'a>(
            &'a self,
            _: &'a [Message],
            _: Option<&'a [Tool]>,
            _: &'a ChatOptions,
        ) -> BoxStream<'a, Result<StreamChunk>> {
            Box::pin(futures::stream::empty())
        }
    }

    /// Tool executor stub: succeeds on every call and exposes no tools.
    struct NoOpExecutor;

    #[async_trait]
    impl ToolExecutor for NoOpExecutor {
        async fn execute(&self, tu: &ToolUse, _: &ToolContext) -> Result<ToolResult> {
            Ok(ToolResult::success(tu.id.clone(), "ok".to_string()))
        }

        fn available_tools(&self) -> Vec<Tool> {
            vec![]
        }
    }

    /// A pool wired with the stub provider/executor and capacity `max`.
    fn make_pool(max: usize) -> AgentPool {
        AgentPool::new(
            max,
            Arc::new(MockProvider::done("Done")),
            Arc::new(NoOpExecutor),
            Arc::new(CommunicationHub::new()),
            Arc::new(FileLockManager::new()),
            "/tmp",
        )
    }

    /// Agent config with validation disabled, as most tests want.
    fn no_validation() -> Option<TaskAgentConfig> {
        Some(TaskAgentConfig {
            validation_config: None,
            ..Default::default()
        })
    }

    #[tokio::test]
    async fn test_pool_creation() {
        assert_eq!(make_pool(5).active_count().await, 0);
    }

    #[tokio::test]
    async fn test_spawn_and_count() {
        let pool = make_pool(5);
        pool.spawn_agent(Task::new("t-1", "Test"), no_validation())
            .await
            .unwrap();
        assert_eq!(pool.active_count().await, 1);
    }

    #[tokio::test]
    async fn test_max_agents_limit() {
        let pool = make_pool(2);

        // Fill the pool to capacity.
        for (id, desc) in [("t-1", "T1"), ("t-2", "T2")] {
            pool.spawn_agent(Task::new(id, desc), no_validation())
                .await
                .unwrap();
        }

        // One more must be rejected with a "full" error.
        let overflow = pool
            .spawn_agent(Task::new("t-3", "T3"), no_validation())
            .await;
        assert!(overflow.is_err());
        assert!(overflow.unwrap_err().to_string().contains("full"));
    }

    #[tokio::test]
    async fn test_await_completion() {
        let pool = make_pool(5);
        let id = pool
            .spawn_agent(Task::new("t-1", "Finish me"), no_validation())
            .await
            .unwrap();

        let result = pool.await_completion(&id).await.unwrap();
        assert!(result.success);
        assert_eq!(result.task_id, "t-1");
    }

    #[tokio::test]
    async fn test_stop_agent() {
        let pool = make_pool(5);
        let id = pool.spawn_agent(Task::new("t-1", "T"), None).await.unwrap();

        pool.stop_agent(&id).await.unwrap();
        assert_eq!(pool.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let pool = make_pool(5);
        for (id, desc) in [("t-1", "T1"), ("t-2", "T2")] {
            pool.spawn_agent(Task::new(id, desc), None).await.unwrap();
        }

        pool.shutdown().await;
        assert_eq!(pool.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_stats() {
        let stats = make_pool(10).stats().await;
        assert_eq!(stats.max_agents, 10);
        assert_eq!(stats.total_agents, 0);
    }

    #[tokio::test]
    async fn test_list_active() {
        let pool = make_pool(5);
        for (id, desc) in [("t-1", "T1"), ("t-2", "T2")] {
            pool.spawn_agent(Task::new(id, desc), None).await.unwrap();
        }

        assert_eq!(pool.list_active().await.len(), 2);
    }
}