Skip to main content

batuta/agent/
pool.rs

1//! Multi-agent orchestration pool.
2//!
3//! Manages concurrent agent instances with message passing
4//! and fan-out/fan-in patterns. Each agent runs its own
5//! perceive-reason-act loop in a separate tokio task.
6//!
7//! # Toyota Production System Principles
8//!
9//! - **Heijunka**: Load-level work across agents
10//! - **Jidoka**: Each agent has its own `LoopGuard`
11//! - **Muda**: Bounded concurrency prevents resource waste
12
13use std::collections::HashMap;
14use std::sync::Arc;
15
16use tokio::sync::mpsc;
17use tokio::task::JoinSet;
18use tracing::{debug, info, warn};
19
20use super::driver::{LlmDriver, StreamEvent};
21use super::manifest::AgentManifest;
22use super::memory::{InMemorySubstrate, MemorySubstrate};
23use super::result::{AgentError, AgentLoopResult};
24use super::tool::ToolRegistry;
25
/// Unique identifier for a spawned agent.
///
/// `AgentPool` assigns identifiers starting at `1`; the value `0` is
/// reserved for messages that originate outside the pool (an external
/// caller or supervisor), see [`AgentMessage::from`].
pub type AgentId = u64;

/// Message sent between agents in the pool.
///
/// Delivered by the message router to the inbox of the agent named in
/// [`to`](Self::to). Equality compares all three fields, which makes
/// messages easy to assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentMessage {
    /// Source agent ID (0 = external/supervisor).
    pub from: AgentId,
    /// Target agent ID.
    pub to: AgentId,
    /// Message payload.
    pub content: String,
}
39
40/// Configuration for a spawned agent.
41pub struct SpawnConfig {
42    /// Agent manifest.
43    pub manifest: AgentManifest,
44    /// Query to execute.
45    pub query: String,
46}
47
48/// Routes messages between agents in a pool.
49///
50/// Each agent gets an inbox (bounded `mpsc` channel). The router
51/// holds senders keyed by `AgentId`, so any agent can send to any
52/// other agent via the shared router reference.
53#[derive(Clone)]
54pub struct MessageRouter {
55    inboxes: Arc<std::sync::RwLock<HashMap<AgentId, mpsc::Sender<AgentMessage>>>>,
56    inbox_capacity: usize,
57}
58
59impl MessageRouter {
60    /// Create a new message router.
61    pub fn new(inbox_capacity: usize) -> Self {
62        Self { inboxes: Arc::new(std::sync::RwLock::new(HashMap::new())), inbox_capacity }
63    }
64
65    /// Register an agent inbox, returning the receiver.
66    pub fn register(&self, agent_id: AgentId) -> mpsc::Receiver<AgentMessage> {
67        let (tx, rx) = mpsc::channel(self.inbox_capacity);
68        let mut inboxes = self.inboxes.write().expect("message router lock");
69        inboxes.insert(agent_id, tx);
70        rx
71    }
72
73    /// Unregister an agent (removes its inbox sender).
74    pub fn unregister(&self, agent_id: AgentId) {
75        let mut inboxes = self.inboxes.write().expect("message router lock");
76        inboxes.remove(&agent_id);
77    }
78
79    /// Send a message to a target agent.
80    ///
81    /// Returns `Err` if target agent is not registered or inbox
82    /// is full (bounded channel protects against backpressure).
83    pub async fn send(&self, msg: AgentMessage) -> Result<(), String> {
84        let tx = {
85            let inboxes = self.inboxes.read().expect("message router lock");
86            inboxes
87                .get(&msg.to)
88                .cloned()
89                .ok_or_else(|| format!("agent {} not registered", msg.to))?
90        };
91        tx.send(msg).await.map_err(|e| format!("inbox closed: {e}"))
92    }
93
94    /// Number of registered agents.
95    pub fn agent_count(&self) -> usize {
96        let inboxes = self.inboxes.read().expect("message router lock");
97        inboxes.len()
98    }
99}
100
101/// Function that builds a `ToolRegistry` from a manifest.
102pub type ToolBuilder = Arc<dyn Fn(&AgentManifest) -> ToolRegistry + Send + Sync>;
103
104/// Multi-agent orchestration pool.
105///
106/// Manages concurrent agent instances, each running its own
107/// perceive-reason-act loop. Supports fan-out (spawn many) and
108/// fan-in (collect results) patterns.
109///
110/// ```rust,ignore
111/// let mut pool = AgentPool::new(driver, 4);
112/// pool.spawn(config1).await?;
113/// pool.spawn(config2).await?;
114/// let results = pool.join_all().await;
115/// ```
116pub struct AgentPool {
117    driver: Arc<dyn LlmDriver>,
118    memory: Arc<dyn MemorySubstrate>,
119    next_id: AgentId,
120    max_concurrent: usize,
121    join_set: JoinSet<(AgentId, String, Result<AgentLoopResult, String>)>,
122    stream_tx: Option<mpsc::Sender<StreamEvent>>,
123    router: MessageRouter,
124    tool_builder: Option<ToolBuilder>,
125}
126
127impl AgentPool {
128    /// Create a new agent pool with bounded concurrency.
129    pub fn new(driver: Arc<dyn LlmDriver>, max_concurrent: usize) -> Self {
130        Self {
131            driver,
132            memory: Arc::new(InMemorySubstrate::new()),
133            next_id: 1,
134            max_concurrent,
135            join_set: JoinSet::new(),
136            stream_tx: None,
137            router: MessageRouter::new(32),
138            tool_builder: None,
139        }
140    }
141
142    /// Access the message router for inter-agent messaging.
143    pub fn router(&self) -> &MessageRouter {
144        &self.router
145    }
146
147    /// Set a shared memory substrate for all agents.
148    #[must_use]
149    pub fn with_memory(mut self, memory: Arc<dyn MemorySubstrate>) -> Self {
150        self.memory = memory;
151        self
152    }
153
154    /// Set a stream event channel for pool-level events.
155    #[must_use]
156    pub fn with_stream(mut self, tx: mpsc::Sender<StreamEvent>) -> Self {
157        self.stream_tx = Some(tx);
158        self
159    }
160
161    /// Set a tool builder for spawned agents.
162    ///
163    /// When set, each spawned agent gets tools built from its
164    /// manifest rather than an empty registry.
165    #[must_use]
166    pub fn with_tool_builder(mut self, builder: ToolBuilder) -> Self {
167        self.tool_builder = Some(builder);
168        self
169    }
170
171    /// Number of currently active agents.
172    pub fn active_count(&self) -> usize {
173        self.join_set.len()
174    }
175
176    /// Maximum concurrent agents allowed.
177    pub fn max_concurrent(&self) -> usize {
178        self.max_concurrent
179    }
180
181    /// Spawn a new agent in the pool.
182    ///
183    /// Returns the `AgentId` assigned to this agent.
184    /// Returns error if pool is at capacity.
185    pub fn spawn(&mut self, config: SpawnConfig) -> Result<AgentId, AgentError> {
186        if self.join_set.len() >= self.max_concurrent {
187            return Err(AgentError::CircuitBreak(format!(
188                "agent pool at capacity ({}/{})",
189                self.join_set.len(),
190                self.max_concurrent
191            )));
192        }
193
194        let id = self.next_id;
195        self.next_id += 1;
196
197        let name = config.manifest.name.clone();
198        let driver = Arc::clone(&self.driver);
199        let memory = Arc::clone(&self.memory);
200        let stream_tx = self.stream_tx.clone();
201
202        // Register agent inbox for inter-agent messaging
203        let _inbox_rx = self.router.register(id);
204        let router = self.router.clone();
205
206        info!(
207            agent_id = id,
208            name = %name,
209            query_len = config.query.len(),
210            "spawning agent"
211        );
212
213        let tool_builder = self.tool_builder.clone();
214
215        self.join_set.spawn(async move {
216            let tools = match tool_builder {
217                Some(builder) => builder(&config.manifest),
218                None => ToolRegistry::new(),
219            };
220            let result = super::runtime::run_agent_loop(
221                &config.manifest,
222                &config.query,
223                driver.as_ref(),
224                &tools,
225                memory.as_ref(),
226                stream_tx,
227            )
228            .await;
229
230            // Unregister agent from router on completion
231            router.unregister(id);
232
233            // Map error to String to avoid Clone requirement
234            let mapped = result.map_err(|e| e.to_string());
235            (id, name, mapped)
236        });
237
238        Ok(id)
239    }
240
241    /// Fan-out: spawn multiple agents concurrently.
242    ///
243    /// Returns a list of `AgentId`s for the spawned agents.
244    pub fn fan_out(&mut self, configs: Vec<SpawnConfig>) -> Result<Vec<AgentId>, AgentError> {
245        let mut ids = Vec::with_capacity(configs.len());
246        for config in configs {
247            ids.push(self.spawn(config)?);
248        }
249        Ok(ids)
250    }
251
252    /// Fan-in: wait for all active agents to complete.
253    ///
254    /// Returns results keyed by `AgentId`. Agents that error
255    /// are included with their error string.
256    pub async fn join_all(&mut self) -> HashMap<AgentId, Result<AgentLoopResult, String>> {
257        let mut results = HashMap::new();
258
259        while let Some(outcome) = self.join_set.join_next().await {
260            match outcome {
261                Ok((id, name, result)) => {
262                    debug!(
263                        agent_id = id,
264                        name = %name,
265                        ok = result.is_ok(),
266                        "agent completed"
267                    );
268                    results.insert(id, result);
269                }
270                Err(e) => {
271                    warn!(error = %e, "agent task panicked");
272                }
273            }
274        }
275
276        results
277    }
278
279    /// Wait for the next agent to complete.
280    ///
281    /// Returns `None` if no agents are active.
282    pub async fn join_next(&mut self) -> Option<(AgentId, Result<AgentLoopResult, String>)> {
283        match self.join_set.join_next().await {
284            Some(Ok((id, _name, result))) => Some((id, result)),
285            Some(Err(e)) => {
286                warn!(error = %e, "agent task panicked");
287                None
288            }
289            None => None,
290        }
291    }
292
293    /// Abort all running agents.
294    pub fn abort_all(&mut self) {
295        self.join_set.abort_all();
296        info!("all agents aborted");
297    }
298}
299
300#[cfg(test)]
301#[path = "pool_tests.rs"]
302mod tests;