use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tracing::{debug, info, warn};
use super::driver::{LlmDriver, StreamEvent};
use super::manifest::AgentManifest;
use super::memory::{InMemorySubstrate, MemorySubstrate};
use super::result::{AgentError, AgentLoopResult};
use super::tool::ToolRegistry;
/// Numeric identifier of an agent.
///
/// [`AgentPool`] hands these out sequentially, starting at 1.
pub type AgentId = u64;
/// A text message sent from one agent to another through a [`MessageRouter`].
#[derive(Debug, Clone)]
pub struct AgentMessage {
    /// Agent that sent the message.
    pub from: AgentId,
    /// Recipient agent; [`MessageRouter::send`] delivers to this agent's inbox.
    pub to: AgentId,
    /// Message body.
    pub content: String,
}
/// Input for launching a single agent: its manifest plus the query it should run.
pub struct SpawnConfig {
    /// Manifest for the agent; `AgentPool` reads its `name` for logging and passes
    /// it to the tool builder and the agent loop.
    pub manifest: AgentManifest,
    /// Query text handed to the agent loop.
    pub query: String,
}
/// Registry of per-agent inboxes used to deliver [`AgentMessage`]s.
///
/// Clones share one registry (the map sits behind an `Arc`), so every clone sees
/// the same registrations.
#[derive(Clone)]
pub struct MessageRouter {
    /// Sender half of each registered agent's inbox, keyed by agent id.
    inboxes: Arc<std::sync::RwLock<HashMap<AgentId, mpsc::Sender<AgentMessage>>>>,
    /// Buffer size of every inbox channel this router creates.
    inbox_capacity: usize,
}
impl MessageRouter {
    /// Creates an empty router whose inboxes each buffer up to `inbox_capacity`
    /// messages.
    ///
    /// NOTE: tokio's `mpsc::channel` panics for a capacity of 0, so `register`
    /// will panic unless `inbox_capacity > 0`.
    pub fn new(inbox_capacity: usize) -> Self {
        let inboxes = Arc::new(std::sync::RwLock::new(HashMap::new()));
        Self { inboxes, inbox_capacity }
    }

    /// Creates a fresh inbox for `agent_id` and returns its receiving end.
    ///
    /// Any inbox previously registered under the same id is replaced.
    pub fn register(&self, agent_id: AgentId) -> mpsc::Receiver<AgentMessage> {
        let (sender, receiver) = mpsc::channel(self.inbox_capacity);
        self.inboxes
            .write()
            .expect("message router lock")
            .insert(agent_id, sender);
        receiver
    }

    /// Removes the inbox registered for `agent_id`, if there is one.
    pub fn unregister(&self, agent_id: AgentId) {
        self.inboxes
            .write()
            .expect("message router lock")
            .remove(&agent_id);
    }

    /// Delivers `msg` to the inbox of `msg.to`, waiting while that inbox is full.
    ///
    /// # Errors
    /// Returns an error string when the recipient is not registered or when its
    /// receiver has been dropped.
    pub async fn send(&self, msg: AgentMessage) -> Result<(), String> {
        // The read guard is a temporary of this statement, so it is released
        // before the `.await` below.
        let lookup = self
            .inboxes
            .read()
            .expect("message router lock")
            .get(&msg.to)
            .cloned();
        let sender = match lookup {
            Some(sender) => sender,
            None => return Err(format!("agent {} not registered", msg.to)),
        };
        sender.send(msg).await.map_err(|e| format!("inbox closed: {e}"))
    }

    /// Number of agents that currently have a registered inbox.
    pub fn agent_count(&self) -> usize {
        self.inboxes.read().expect("message router lock").len()
    }
}
/// Shared factory that builds the [`ToolRegistry`] for an agent from its manifest.
///
/// `AgentPool` calls it once per spawned agent, inside that agent's task.
pub type ToolBuilder = Arc<dyn Fn(&AgentManifest) -> ToolRegistry + Send + Sync>;
/// Runs agents as Tokio tasks, capped at a configurable number of concurrent tasks.
pub struct AgentPool {
    /// LLM backend shared by every agent.
    driver: Arc<dyn LlmDriver>,
    /// Memory substrate shared by every agent.
    memory: Arc<dyn MemorySubstrate>,
    /// Id that the next spawned agent will receive.
    next_id: AgentId,
    /// Upper bound on the number of tasks held in `join_set`.
    max_concurrent: usize,
    /// Agent tasks, including finished ones not yet joined; each yields
    /// `(id, manifest name, result)`.
    join_set: JoinSet<(AgentId, String, Result<AgentLoopResult, String>)>,
    /// Optional sink for stream events, cloned into each agent.
    stream_tx: Option<mpsc::Sender<StreamEvent>>,
    /// Inbox registry for agents spawned by this pool.
    router: MessageRouter,
    /// Optional per-agent tool factory; without it, agents get an empty registry.
    tool_builder: Option<ToolBuilder>,
}
impl AgentPool {
pub fn new(driver: Arc<dyn LlmDriver>, max_concurrent: usize) -> Self {
Self {
driver,
memory: Arc::new(InMemorySubstrate::new()),
next_id: 1,
max_concurrent,
join_set: JoinSet::new(),
stream_tx: None,
router: MessageRouter::new(32),
tool_builder: None,
}
}
pub fn router(&self) -> &MessageRouter {
&self.router
}
#[must_use]
pub fn with_memory(mut self, memory: Arc<dyn MemorySubstrate>) -> Self {
self.memory = memory;
self
}
#[must_use]
pub fn with_stream(mut self, tx: mpsc::Sender<StreamEvent>) -> Self {
self.stream_tx = Some(tx);
self
}
#[must_use]
pub fn with_tool_builder(mut self, builder: ToolBuilder) -> Self {
self.tool_builder = Some(builder);
self
}
pub fn active_count(&self) -> usize {
self.join_set.len()
}
pub fn max_concurrent(&self) -> usize {
self.max_concurrent
}
pub fn spawn(&mut self, config: SpawnConfig) -> Result<AgentId, AgentError> {
if self.join_set.len() >= self.max_concurrent {
return Err(AgentError::CircuitBreak(format!(
"agent pool at capacity ({}/{})",
self.join_set.len(),
self.max_concurrent
)));
}
let id = self.next_id;
self.next_id += 1;
let name = config.manifest.name.clone();
let driver = Arc::clone(&self.driver);
let memory = Arc::clone(&self.memory);
let stream_tx = self.stream_tx.clone();
let _inbox_rx = self.router.register(id);
let router = self.router.clone();
info!(
agent_id = id,
name = %name,
query_len = config.query.len(),
"spawning agent"
);
let tool_builder = self.tool_builder.clone();
self.join_set.spawn(async move {
let tools = match tool_builder {
Some(builder) => builder(&config.manifest),
None => ToolRegistry::new(),
};
let result = super::runtime::run_agent_loop(
&config.manifest,
&config.query,
driver.as_ref(),
&tools,
memory.as_ref(),
stream_tx,
)
.await;
router.unregister(id);
let mapped = result.map_err(|e| e.to_string());
(id, name, mapped)
});
Ok(id)
}
pub fn fan_out(&mut self, configs: Vec<SpawnConfig>) -> Result<Vec<AgentId>, AgentError> {
let mut ids = Vec::with_capacity(configs.len());
for config in configs {
ids.push(self.spawn(config)?);
}
Ok(ids)
}
pub async fn join_all(&mut self) -> HashMap<AgentId, Result<AgentLoopResult, String>> {
let mut results = HashMap::new();
while let Some(outcome) = self.join_set.join_next().await {
match outcome {
Ok((id, name, result)) => {
debug!(
agent_id = id,
name = %name,
ok = result.is_ok(),
"agent completed"
);
results.insert(id, result);
}
Err(e) => {
warn!(error = %e, "agent task panicked");
}
}
}
results
}
pub async fn join_next(&mut self) -> Option<(AgentId, Result<AgentLoopResult, String>)> {
match self.join_set.join_next().await {
Some(Ok((id, _name, result))) => Some((id, result)),
Some(Err(e)) => {
warn!(error = %e, "agent task panicked");
None
}
None => None,
}
}
pub fn abort_all(&mut self) {
self.join_set.abort_all();
info!("all agents aborted");
}
}
#[cfg(test)]
#[path = "pool_tests.rs"]
mod tests;