1use std::collections::HashMap;
32use std::sync::Arc;
33
34use anyhow::{Result, anyhow};
35use tokio::sync::RwLock;
36use tokio::task::JoinHandle;
37
38use brainwires_core::{Provider, Task};
39use brainwires_tool_system::ToolExecutor;
40
41use crate::communication::CommunicationHub;
42use crate::context::AgentContext;
43use crate::file_locks::FileLockManager;
44use crate::task_agent::{
45 TaskAgent, TaskAgentConfig, TaskAgentResult, TaskAgentStatus, spawn_task_agent,
46};
47
48struct AgentHandle {
51 agent: Arc<TaskAgent>,
52 join_handle: JoinHandle<Result<TaskAgentResult>>,
53}
54
55pub struct AgentPool {
63 max_agents: usize,
64 agents: Arc<RwLock<HashMap<String, AgentHandle>>>,
65 communication_hub: Arc<CommunicationHub>,
66 file_lock_manager: Arc<FileLockManager>,
67 provider: Arc<dyn Provider>,
68 tool_executor: Arc<dyn ToolExecutor>,
69 working_directory: String,
70}
71
72impl AgentPool {
73 pub fn new(
83 max_agents: usize,
84 provider: Arc<dyn Provider>,
85 tool_executor: Arc<dyn ToolExecutor>,
86 communication_hub: Arc<CommunicationHub>,
87 file_lock_manager: Arc<FileLockManager>,
88 working_directory: impl Into<String>,
89 ) -> Self {
90 Self {
91 max_agents,
92 agents: Arc::new(RwLock::new(HashMap::new())),
93 communication_hub,
94 file_lock_manager,
95 provider,
96 tool_executor,
97 working_directory: working_directory.into(),
98 }
99 }
100
101 pub async fn spawn_agent(&self, task: Task, config: Option<TaskAgentConfig>) -> Result<String> {
108 {
109 let agents = self.agents.read().await;
110 if agents.len() >= self.max_agents {
111 return Err(anyhow!(
112 "Agent pool is full ({}/{})",
113 agents.len(),
114 self.max_agents
115 ));
116 }
117 }
118
119 let agent_id = format!("agent-{}", uuid::Uuid::new_v4());
120 let config = config.unwrap_or_default();
121
122 let context = Arc::new(AgentContext::new(
123 self.working_directory.clone(),
124 Arc::clone(&self.tool_executor),
125 Arc::clone(&self.communication_hub),
126 Arc::clone(&self.file_lock_manager),
127 ));
128
129 let agent = Arc::new(TaskAgent::new(
130 agent_id.clone(),
131 task,
132 Arc::clone(&self.provider),
133 context,
134 config,
135 ));
136
137 let handle = spawn_task_agent(Arc::clone(&agent));
138
139 self.agents.write().await.insert(
140 agent_id.clone(),
141 AgentHandle {
142 agent,
143 join_handle: handle,
144 },
145 );
146
147 tracing::info!(agent_id = %agent_id, "spawned agent");
148 Ok(agent_id)
149 }
150
151 pub async fn spawn_agent_with_context(
160 &self,
161 task: Task,
162 context: Arc<AgentContext>,
163 config: Option<TaskAgentConfig>,
164 ) -> Result<String> {
165 {
166 let agents = self.agents.read().await;
167 if agents.len() >= self.max_agents {
168 return Err(anyhow!(
169 "Agent pool is full ({}/{})",
170 agents.len(),
171 self.max_agents
172 ));
173 }
174 }
175
176 let agent_id = format!("agent-{}", uuid::Uuid::new_v4());
177 let config = config.unwrap_or_default();
178
179 let agent = Arc::new(TaskAgent::new(
180 agent_id.clone(),
181 task,
182 Arc::clone(&self.provider),
183 context,
184 config,
185 ));
186
187 let handle = spawn_task_agent(Arc::clone(&agent));
188
189 self.agents.write().await.insert(
190 agent_id.clone(),
191 AgentHandle {
192 agent,
193 join_handle: handle,
194 },
195 );
196
197 tracing::info!(agent_id = %agent_id, "spawned agent with custom context");
198 Ok(agent_id)
199 }
200
201 pub async fn get_status(&self, agent_id: &str) -> Option<TaskAgentStatus> {
205 let agents = self.agents.read().await;
206 let handle = agents.get(agent_id)?;
207 Some(handle.agent.status().await)
208 }
209
210 pub async fn get_task(&self, agent_id: &str) -> Option<Task> {
212 let agents = self.agents.read().await;
213 let handle = agents.get(agent_id)?;
214 Some(handle.agent.task().await)
215 }
216
217 pub async fn stop_agent(&self, agent_id: &str) -> Result<()> {
221 let handle = self
222 .agents
223 .write()
224 .await
225 .remove(agent_id)
226 .ok_or_else(|| anyhow!("Agent {} not found", agent_id))?;
227
228 handle.join_handle.abort();
229 self.file_lock_manager.release_all_locks(agent_id).await;
230 tracing::info!(agent_id = %agent_id, "stopped agent");
231 Ok(())
232 }
233
234 pub async fn await_completion(&self, agent_id: &str) -> Result<TaskAgentResult> {
238 let handle = self.agents.write().await.remove(agent_id);
239
240 match handle {
241 Some(h) => match h.join_handle.await {
242 Ok(result) => result,
243 Err(e) => Err(anyhow!("Agent task panicked: {}", e)),
244 },
245 None => Err(anyhow!("Agent {} not found", agent_id)),
246 }
247 }
248
249 pub async fn list_active(&self) -> Vec<(String, TaskAgentStatus)> {
251 let agents = self.agents.read().await;
252 let mut out = Vec::with_capacity(agents.len());
253 for (id, handle) in agents.iter() {
254 out.push((id.clone(), handle.agent.status().await));
255 }
256 out
257 }
258
259 pub async fn active_count(&self) -> usize {
261 self.agents.read().await.len()
262 }
263
264 pub async fn is_running(&self, agent_id: &str) -> bool {
266 let agents = self.agents.read().await;
267 agents
268 .get(agent_id)
269 .map(|h| !h.join_handle.is_finished())
270 .unwrap_or(false)
271 }
272
273 pub async fn cleanup_completed(&self) -> Vec<(String, Result<TaskAgentResult>)> {
275 let finished: Vec<String> = {
276 let agents = self.agents.read().await;
277 agents
278 .iter()
279 .filter(|(_, h)| h.join_handle.is_finished())
280 .map(|(id, _)| id.clone())
281 .collect()
282 };
283
284 let mut results = Vec::new();
285 let mut agents = self.agents.write().await;
286 for id in finished {
287 if let Some(handle) = agents.remove(&id) {
288 let result = match handle.join_handle.await {
289 Ok(r) => r,
290 Err(e) => Err(anyhow!("Agent task panicked: {}", e)),
291 };
292 results.push((id, result));
293 }
294 }
295 results
296 }
297
298 pub async fn await_all(&self) -> Vec<(String, Result<TaskAgentResult>)> {
300 let ids: Vec<String> = self.agents.read().await.keys().cloned().collect();
301 let mut results = Vec::new();
302 for id in ids {
303 results.push((id.clone(), self.await_completion(&id).await));
304 }
305 results
306 }
307
308 pub async fn shutdown(&self) {
310 let mut agents = self.agents.write().await;
311 for (agent_id, handle) in agents.drain() {
312 handle.join_handle.abort();
313 self.file_lock_manager.release_all_locks(&agent_id).await;
314 }
315 tracing::info!("agent pool shut down");
316 }
317
318 pub async fn stats(&self) -> AgentPoolStats {
320 let agents = self.agents.read().await;
321 let mut running = 0usize;
322 let mut completed = 0usize;
323
324 for (_, handle) in agents.iter() {
325 if handle.join_handle.is_finished() {
326 completed += 1;
327 } else {
328 running += 1;
329 }
330 }
331
332 AgentPoolStats {
333 max_agents: self.max_agents,
334 total_agents: agents.len(),
335 running,
336 completed,
337 failed: 0, }
339 }
340
341 pub fn file_lock_manager(&self) -> Arc<FileLockManager> {
343 Arc::clone(&self.file_lock_manager)
344 }
345
346 pub fn communication_hub(&self) -> Arc<CommunicationHub> {
348 Arc::clone(&self.communication_hub)
349 }
350}
351
/// Point-in-time occupancy snapshot of an `AgentPool`, as returned by `AgentPool::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AgentPoolStats {
    /// Configured capacity of the pool.
    pub max_agents: usize,
    /// Number of tracked agents: running plus finished-but-not-collected.
    pub total_agents: usize,
    /// Tracked agents whose task has not finished yet.
    pub running: usize,
    /// Tracked agents whose task has finished, whatever the outcome.
    pub completed: usize,
    /// Reserved for agents known to have failed. `AgentPool::stats` does not inspect
    /// results and currently always reports 0.
    pub failed: usize,
}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::communication::CommunicationHub;
    use crate::file_locks::FileLockManager;
    use async_trait::async_trait;
    use brainwires_core::{
        ChatOptions, ChatResponse, Message, StreamChunk, Tool, ToolContext, ToolResult, ToolUse,
        Usage,
    };
    use brainwires_tool_system::ToolExecutor;
    use futures::stream::BoxStream;

    /// Provider that answers every chat request with the same canned response.
    struct MockProvider(ChatResponse);

    impl MockProvider {
        /// A provider whose reply is an assistant message `text` with finish reason "stop".
        fn done(text: &str) -> Self {
            Self(ChatResponse {
                message: Message::assistant(text),
                finish_reason: Some("stop".to_string()),
                usage: Usage::default(),
            })
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn chat(
            &self,
            _: &[Message],
            _: Option<&[Tool]>,
            _: &ChatOptions,
        ) -> Result<ChatResponse> {
            Ok(self.0.clone())
        }

        fn stream_chat<'a>(
            &'a self,
            _: &'a [Message],
            _: Option<&'a [Tool]>,
            _: &'a ChatOptions,
        ) -> BoxStream<'a, Result<StreamChunk>> {
            Box::pin(futures::stream::empty())
        }
    }

    /// Tool executor that reports success for every call without doing anything.
    struct NoOpExecutor;

    #[async_trait]
    impl ToolExecutor for NoOpExecutor {
        async fn execute(&self, tu: &ToolUse, _: &ToolContext) -> Result<ToolResult> {
            Ok(ToolResult::success(tu.id.clone(), "ok".to_string()))
        }

        fn available_tools(&self) -> Vec<Tool> {
            vec![]
        }
    }

    /// Builds a pool of capacity `max` backed by the mock provider and no-op executor.
    fn make_pool(max: usize) -> AgentPool {
        AgentPool::new(
            max,
            Arc::new(MockProvider::done("Done")),
            Arc::new(NoOpExecutor),
            Arc::new(CommunicationHub::new()),
            Arc::new(FileLockManager::new()),
            "/tmp",
        )
    }

    /// Agent config with `validation_config` disabled, used where a test needs agents to
    /// complete successfully.
    fn no_validation() -> Option<TaskAgentConfig> {
        Some(TaskAgentConfig {
            validation_config: None,
            ..Default::default()
        })
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = make_pool(5);
        assert_eq!(pool.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_spawn_and_count() {
        let pool = make_pool(5);
        pool.spawn_agent(Task::new("t-1", "Test"), no_validation())
            .await
            .unwrap();
        assert_eq!(pool.active_count().await, 1);
    }

    #[tokio::test]
    async fn test_max_agents_limit() {
        let pool = make_pool(2);

        pool.spawn_agent(Task::new("t-1", "T1"), no_validation())
            .await
            .unwrap();
        pool.spawn_agent(Task::new("t-2", "T2"), no_validation())
            .await
            .unwrap();

        let err = pool.spawn_agent(Task::new("t-3", "T3"), no_validation()).await;
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("full"));
    }

    #[tokio::test]
    async fn test_stopped_agent_frees_capacity() {
        let pool = make_pool(1);
        let id = pool
            .spawn_agent(Task::new("t-1", "T1"), no_validation())
            .await
            .unwrap();
        assert!(pool
            .spawn_agent(Task::new("t-2", "T2"), no_validation())
            .await
            .is_err());

        pool.stop_agent(&id).await.unwrap();
        pool.spawn_agent(Task::new("t-3", "T3"), no_validation())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_await_completion() {
        let pool = make_pool(5);
        let id = pool
            .spawn_agent(Task::new("t-1", "Finish me"), no_validation())
            .await
            .unwrap();

        let result = pool.await_completion(&id).await.unwrap();
        assert!(result.success);
        assert_eq!(result.task_id, "t-1");
    }

    #[tokio::test]
    async fn test_await_all_drains_pool() {
        let pool = make_pool(5);
        pool.spawn_agent(Task::new("t-1", "T1"), no_validation())
            .await
            .unwrap();
        pool.spawn_agent(Task::new("t-2", "T2"), no_validation())
            .await
            .unwrap();

        let results = pool.await_all().await;
        assert_eq!(results.len(), 2);
        for (_, result) in results {
            assert!(result.unwrap().success);
        }
        assert_eq!(pool.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_cleanup_completed_on_empty_pool() {
        let pool = make_pool(5);
        assert!(pool.cleanup_completed().await.is_empty());
    }

    #[tokio::test]
    async fn test_unknown_agent_lookups() {
        let pool = make_pool(5);
        assert!(pool.get_status("missing").await.is_none());
        assert!(pool.get_task("missing").await.is_none());
        assert!(!pool.is_running("missing").await);

        let err = pool.stop_agent("missing").await.unwrap_err();
        assert!(err.to_string().contains("not found"));

        let err = pool
            .await_completion("missing")
            .await
            .err()
            .expect("awaiting an unknown agent must fail");
        assert!(err.to_string().contains("not found"));
    }

    #[tokio::test]
    async fn test_stop_agent() {
        let pool = make_pool(5);
        let id = pool.spawn_agent(Task::new("t-1", "T"), None).await.unwrap();

        pool.stop_agent(&id).await.unwrap();
        assert_eq!(pool.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_shutdown() {
        let pool = make_pool(5);
        pool.spawn_agent(Task::new("t-1", "T1"), None)
            .await
            .unwrap();
        pool.spawn_agent(Task::new("t-2", "T2"), None)
            .await
            .unwrap();

        pool.shutdown().await;
        assert_eq!(pool.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_stats() {
        let pool = make_pool(10);
        let stats = pool.stats().await;
        assert_eq!(stats.max_agents, 10);
        assert_eq!(stats.total_agents, 0);

        pool.spawn_agent(Task::new("t-1", "T1"), no_validation())
            .await
            .unwrap();
        let stats = pool.stats().await;
        assert_eq!(stats.total_agents, 1);
        assert_eq!(stats.running + stats.completed, 1);
    }

    #[tokio::test]
    async fn test_list_active() {
        let pool = make_pool(5);
        pool.spawn_agent(Task::new("t-1", "T1"), None)
            .await
            .unwrap();
        pool.spawn_agent(Task::new("t-2", "T2"), None)
            .await
            .unwrap();

        let active = pool.list_active().await;
        assert_eq!(active.len(), 2);
    }
}