1use std::collections::HashMap;
14use std::sync::Arc;
15
16use tokio::sync::mpsc;
17use tokio::task::JoinSet;
18use tracing::{debug, info, warn};
19
20use super::driver::{LlmDriver, StreamEvent};
21use super::manifest::AgentManifest;
22use super::memory::{InMemorySubstrate, MemorySubstrate};
23use super::result::{AgentError, AgentLoopResult};
24use super::tool::ToolRegistry;
25
/// Identifier assigned to each agent spawned by an `AgentPool` (handed out
/// sequentially, starting at 1) and used by `MessageRouter` to address inboxes.
pub type AgentId = u64;

/// A text message sent from one agent to another.
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly (e.g. in
/// assertions or when de-duplicating); they are plain value types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentMessage {
    /// Id of the sending agent.
    pub from: AgentId,
    /// Id of the recipient; `MessageRouter::send` routes on this field.
    pub to: AgentId,
    /// Message body.
    pub content: String,
}
39
40pub struct SpawnConfig {
42 pub manifest: AgentManifest,
44 pub query: String,
46}
47
48#[derive(Clone)]
54pub struct MessageRouter {
55 inboxes: Arc<std::sync::RwLock<HashMap<AgentId, mpsc::Sender<AgentMessage>>>>,
56 inbox_capacity: usize,
57}
58
59impl MessageRouter {
60 pub fn new(inbox_capacity: usize) -> Self {
62 Self { inboxes: Arc::new(std::sync::RwLock::new(HashMap::new())), inbox_capacity }
63 }
64
65 pub fn register(&self, agent_id: AgentId) -> mpsc::Receiver<AgentMessage> {
67 let (tx, rx) = mpsc::channel(self.inbox_capacity);
68 let mut inboxes = self.inboxes.write().expect("message router lock");
69 inboxes.insert(agent_id, tx);
70 rx
71 }
72
73 pub fn unregister(&self, agent_id: AgentId) {
75 let mut inboxes = self.inboxes.write().expect("message router lock");
76 inboxes.remove(&agent_id);
77 }
78
79 pub async fn send(&self, msg: AgentMessage) -> Result<(), String> {
84 let tx = {
85 let inboxes = self.inboxes.read().expect("message router lock");
86 inboxes
87 .get(&msg.to)
88 .cloned()
89 .ok_or_else(|| format!("agent {} not registered", msg.to))?
90 };
91 tx.send(msg).await.map_err(|e| format!("inbox closed: {e}"))
92 }
93
94 pub fn agent_count(&self) -> usize {
96 let inboxes = self.inboxes.read().expect("message router lock");
97 inboxes.len()
98 }
99}
100
101pub type ToolBuilder = Arc<dyn Fn(&AgentManifest) -> ToolRegistry + Send + Sync>;
103
104pub struct AgentPool {
117 driver: Arc<dyn LlmDriver>,
118 memory: Arc<dyn MemorySubstrate>,
119 next_id: AgentId,
120 max_concurrent: usize,
121 join_set: JoinSet<(AgentId, String, Result<AgentLoopResult, String>)>,
122 stream_tx: Option<mpsc::Sender<StreamEvent>>,
123 router: MessageRouter,
124 tool_builder: Option<ToolBuilder>,
125}
126
127impl AgentPool {
128 pub fn new(driver: Arc<dyn LlmDriver>, max_concurrent: usize) -> Self {
130 Self {
131 driver,
132 memory: Arc::new(InMemorySubstrate::new()),
133 next_id: 1,
134 max_concurrent,
135 join_set: JoinSet::new(),
136 stream_tx: None,
137 router: MessageRouter::new(32),
138 tool_builder: None,
139 }
140 }
141
142 pub fn router(&self) -> &MessageRouter {
144 &self.router
145 }
146
147 #[must_use]
149 pub fn with_memory(mut self, memory: Arc<dyn MemorySubstrate>) -> Self {
150 self.memory = memory;
151 self
152 }
153
154 #[must_use]
156 pub fn with_stream(mut self, tx: mpsc::Sender<StreamEvent>) -> Self {
157 self.stream_tx = Some(tx);
158 self
159 }
160
161 #[must_use]
166 pub fn with_tool_builder(mut self, builder: ToolBuilder) -> Self {
167 self.tool_builder = Some(builder);
168 self
169 }
170
171 pub fn active_count(&self) -> usize {
173 self.join_set.len()
174 }
175
176 pub fn max_concurrent(&self) -> usize {
178 self.max_concurrent
179 }
180
181 pub fn spawn(&mut self, config: SpawnConfig) -> Result<AgentId, AgentError> {
186 if self.join_set.len() >= self.max_concurrent {
187 return Err(AgentError::CircuitBreak(format!(
188 "agent pool at capacity ({}/{})",
189 self.join_set.len(),
190 self.max_concurrent
191 )));
192 }
193
194 let id = self.next_id;
195 self.next_id += 1;
196
197 let name = config.manifest.name.clone();
198 let driver = Arc::clone(&self.driver);
199 let memory = Arc::clone(&self.memory);
200 let stream_tx = self.stream_tx.clone();
201
202 let _inbox_rx = self.router.register(id);
204 let router = self.router.clone();
205
206 info!(
207 agent_id = id,
208 name = %name,
209 query_len = config.query.len(),
210 "spawning agent"
211 );
212
213 let tool_builder = self.tool_builder.clone();
214
215 self.join_set.spawn(async move {
216 let tools = match tool_builder {
217 Some(builder) => builder(&config.manifest),
218 None => ToolRegistry::new(),
219 };
220 let result = super::runtime::run_agent_loop(
221 &config.manifest,
222 &config.query,
223 driver.as_ref(),
224 &tools,
225 memory.as_ref(),
226 stream_tx,
227 )
228 .await;
229
230 router.unregister(id);
232
233 let mapped = result.map_err(|e| e.to_string());
235 (id, name, mapped)
236 });
237
238 Ok(id)
239 }
240
241 pub fn fan_out(&mut self, configs: Vec<SpawnConfig>) -> Result<Vec<AgentId>, AgentError> {
245 let mut ids = Vec::with_capacity(configs.len());
246 for config in configs {
247 ids.push(self.spawn(config)?);
248 }
249 Ok(ids)
250 }
251
252 pub async fn join_all(&mut self) -> HashMap<AgentId, Result<AgentLoopResult, String>> {
257 let mut results = HashMap::new();
258
259 while let Some(outcome) = self.join_set.join_next().await {
260 match outcome {
261 Ok((id, name, result)) => {
262 debug!(
263 agent_id = id,
264 name = %name,
265 ok = result.is_ok(),
266 "agent completed"
267 );
268 results.insert(id, result);
269 }
270 Err(e) => {
271 warn!(error = %e, "agent task panicked");
272 }
273 }
274 }
275
276 results
277 }
278
279 pub async fn join_next(&mut self) -> Option<(AgentId, Result<AgentLoopResult, String>)> {
283 match self.join_set.join_next().await {
284 Some(Ok((id, _name, result))) => Some((id, result)),
285 Some(Err(e)) => {
286 warn!(error = %e, "agent task panicked");
287 None
288 }
289 None => None,
290 }
291 }
292
293 pub fn abort_all(&mut self) {
295 self.join_set.abort_all();
296 info!("all agents aborted");
297 }
298}
299
// Unit tests for this module are loaded from `pool_tests.rs` (see the `#[path]`
// attribute) and are compiled only for test builds.
#[cfg(test)]
#[path = "pool_tests.rs"]
mod tests;