use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinSet;
use crate::error::{AgenticError, Result};
use super::output::AgentOutput;
use super::werk::Agent;
const DEFAULT_BATCH_SIZE: usize = 10;
pub type JobId = u64;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolStrategy {
CompletionOrder,
SpawnOrder,
}
impl Default for PoolStrategy {
fn default() -> Self {
Self::CompletionOrder
}
}
pub struct AgentPool {
batch_size: usize,
ordering: PoolStrategy,
semaphore: Arc<Semaphore>,
state: Mutex<PoolState>,
next_id: AtomicU64,
}
struct PoolState {
join_set: JoinSet<(JobId, Result<AgentOutput>)>,
buffer: BTreeMap<JobId, Result<AgentOutput>>,
next_expected: JobId,
}
impl AgentPool {
pub fn new() -> Self {
Self {
batch_size: DEFAULT_BATCH_SIZE,
ordering: PoolStrategy::default(),
semaphore: Arc::new(Semaphore::new(DEFAULT_BATCH_SIZE)),
state: Mutex::new(PoolState {
join_set: JoinSet::new(),
buffer: BTreeMap::new(),
next_expected: 0,
}),
next_id: AtomicU64::new(0),
}
}
pub fn batch_size(mut self, n: usize) -> Self {
self.batch_size = n.max(1);
self.semaphore = Arc::new(Semaphore::new(self.batch_size));
self
}
pub fn ordering(mut self, o: PoolStrategy) -> Self {
self.ordering = o;
self
}
pub async fn spawn(&self, agent: Agent) -> JobId {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let permit = self
.semaphore
.clone()
.acquire_owned()
.await
.expect("pool semaphore closed unexpectedly");
let mut st = self.state.lock().await;
st.join_set.spawn(async move {
let result = agent.run().await;
drop(permit);
(id, result)
});
id
}
pub async fn next(&self) -> Option<(JobId, Result<AgentOutput>)> {
match self.ordering {
PoolStrategy::CompletionOrder => self.next_by_completion().await,
PoolStrategy::SpawnOrder => self.next_by_spawn_order().await,
}
}
pub async fn drain(&self) -> Vec<(JobId, Result<AgentOutput>)> {
let mut out = Vec::new();
while let Some(entry) = self.next().await {
out.push(entry);
}
out
}
async fn next_by_completion(&self) -> Option<(JobId, Result<AgentOutput>)> {
let mut st = self.state.lock().await;
if st.join_set.is_empty() {
return None;
}
match st.join_set.join_next().await {
Some(Ok(pair)) => Some(pair),
Some(Err(join_err)) => Some((
u64::MAX,
Err(AgenticError::Other(format!("task join error: {join_err}"))),
)),
None => None,
}
}
async fn next_by_spawn_order(&self) -> Option<(JobId, Result<AgentOutput>)> {
loop {
let mut st = self.state.lock().await;
let next_id = st.next_expected;
if let Some(r) = st.buffer.remove(&next_id) {
st.next_expected = next_id + 1;
return Some((next_id, r));
}
if st.join_set.is_empty() && st.buffer.is_empty() {
return None;
}
let next = match st.join_set.join_next().await {
Some(Ok(pair)) => pair,
Some(Err(join_err)) => (
u64::MAX,
Err(AgenticError::Other(format!("task join error: {join_err}"))),
),
None => return None,
};
st.buffer.insert(next.0, next.1);
}
}
}
impl Default for AgentPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::testutil::{text_response, tool_response, MockProvider};
use crate::tools::{ToolBuilder, ToolResult};
use std::sync::atomic::AtomicUsize;
use std::time::Duration;
fn agent_with_response(text: &str) -> Agent {
let provider = Arc::new(MockProvider::text(text));
Agent::new()
.name("test")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(provider)
}
#[tokio::test]
async fn pool_drain_spawn_order() {
let pool = AgentPool::new()
.batch_size(2)
.ordering(PoolStrategy::SpawnOrder);
pool.spawn(agent_with_response("first")).await;
pool.spawn(agent_with_response("second")).await;
pool.spawn(agent_with_response("third")).await;
let results = pool.drain().await;
assert_eq!(results.len(), 3);
assert_eq!(results[0].0, 0);
assert_eq!(results[1].0, 1);
assert_eq!(results[2].0, 2);
assert_eq!(results[0].1.as_ref().unwrap().response_raw, "first");
assert_eq!(results[1].1.as_ref().unwrap().response_raw, "second");
assert_eq!(results[2].1.as_ref().unwrap().response_raw, "third");
}
#[tokio::test]
async fn pool_individual_failures() {
let pool = AgentPool::new()
.batch_size(2)
.ordering(PoolStrategy::SpawnOrder);
pool.spawn(agent_with_response("ok")).await;
pool.spawn({
let provider = Arc::new(MockProvider::new(vec![]));
Agent::new()
.name("fail")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(provider)
})
.await;
pool.spawn(agent_with_response("also ok")).await;
let results = pool.drain().await;
assert_eq!(results.len(), 3);
assert!(results[0].1.is_ok());
assert!(results[1].1.is_err());
assert!(results[2].1.is_ok());
}
#[tokio::test]
async fn pool_empty() {
let pool = AgentPool::new();
let results = pool.drain().await;
assert!(results.is_empty());
}
#[tokio::test]
async fn pool_runs_concurrently() {
let running = Arc::new(AtomicUsize::new(0));
let max_concurrent = Arc::new(AtomicUsize::new(0));
let pool = AgentPool::new().batch_size(3);
for _ in 0..6 {
let running = running.clone();
let max_concurrent = max_concurrent.clone();
let slow_tool = ToolBuilder::new("slow", "Simulates slow work")
.schema(serde_json::json!({"type": "object", "properties": {}}))
.handler(move |_, _| {
let running = running.clone();
let max_concurrent = max_concurrent.clone();
Box::pin(async move {
let current = running.fetch_add(1, Ordering::SeqCst) + 1;
max_concurrent.fetch_max(current, Ordering::SeqCst);
tokio::time::sleep(Duration::from_millis(50)).await;
running.fetch_sub(1, Ordering::SeqCst);
Ok(ToolResult::success("done"))
})
})
.build();
let provider = Arc::new(MockProvider::new(vec![
tool_response("slow", "c1", serde_json::json!({})),
text_response("finished"),
]));
pool.spawn(
Agent::new()
.name("worker")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.tool(slow_tool)
.provider(provider),
)
.await;
}
let results = pool.drain().await;
assert_eq!(results.len(), 6);
assert!(results.iter().all(|r| r.1.is_ok()));
assert!(
max_concurrent.load(Ordering::SeqCst) >= 3,
"Expected at least 3 concurrent agents, got {}",
max_concurrent.load(Ordering::SeqCst)
);
}
#[tokio::test]
async fn pool_spawn_order_buffers_fast_finishers() {
let slow_tool = ToolBuilder::new("slow", "slow tool")
.schema(serde_json::json!({"type": "object", "properties": {}}))
.handler(|_, _| {
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(80)).await;
Ok(ToolResult::success("slow done"))
})
})
.build();
let a = Agent::new()
.name("A")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.tool(slow_tool)
.provider(Arc::new(MockProvider::new(vec![
tool_response("slow", "c1", serde_json::json!({})),
text_response("A-done"),
])));
let b = Agent::new()
.name("B")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(Arc::new(MockProvider::text("B-done")));
let pool = AgentPool::new()
.batch_size(2)
.ordering(PoolStrategy::SpawnOrder);
pool.spawn(a).await;
pool.spawn(b).await;
let results = pool.drain().await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, 0);
assert_eq!(results[0].1.as_ref().unwrap().response_raw, "A-done");
assert_eq!(results[1].0, 1);
assert_eq!(results[1].1.as_ref().unwrap().response_raw, "B-done");
}
#[tokio::test]
async fn pool_completion_order_yields_fast_first() {
let slow_tool = ToolBuilder::new("slow", "slow tool")
.schema(serde_json::json!({"type": "object", "properties": {}}))
.handler(|_, _| {
Box::pin(async move {
tokio::time::sleep(Duration::from_millis(80)).await;
Ok(ToolResult::success("slow done"))
})
})
.build();
let a = Agent::new()
.name("A")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.tool(slow_tool)
.provider(Arc::new(MockProvider::new(vec![
tool_response("slow", "c1", serde_json::json!({})),
text_response("A-done"),
])));
let b = Agent::new()
.name("B")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(Arc::new(MockProvider::text("B-done")));
let pool = AgentPool::new()
.batch_size(2)
.ordering(PoolStrategy::CompletionOrder);
pool.spawn(a).await;
pool.spawn(b).await;
let results = pool.drain().await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].1.as_ref().unwrap().response_raw, "B-done");
assert_eq!(results[1].1.as_ref().unwrap().response_raw, "A-done");
}
#[tokio::test]
async fn pool_completion_order_failure_does_not_block() {
let pool = AgentPool::new()
.batch_size(2)
.ordering(PoolStrategy::CompletionOrder);
pool.spawn({
let provider = Arc::new(MockProvider::new(vec![]));
Agent::new()
.name("fail")
.model("mock")
.identity_prompt("")
.instruction_prompt("go")
.provider(provider)
})
.await;
pool.spawn(agent_with_response("ok")).await;
let results = pool.drain().await;
assert_eq!(results.len(), 2);
let ok_count = results.iter().filter(|r| r.1.is_ok()).count();
let err_count = results.iter().filter(|r| r.1.is_err()).count();
assert_eq!(ok_count, 1);
assert_eq!(err_count, 1);
}
#[tokio::test]
async fn pool_dynamic_spawn_while_running() {
let pool = Arc::new(AgentPool::new().batch_size(2));
pool.spawn(agent_with_response("first")).await;
let pool2 = pool.clone();
let spawner = tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(20)).await;
pool2.spawn(agent_with_response("second")).await;
});
let first = pool.next().await.unwrap();
assert_eq!(first.1.as_ref().unwrap().response_raw, "first");
spawner.await.unwrap();
let second = pool.next().await.unwrap();
assert_eq!(second.1.as_ref().unwrap().response_raw, "second");
assert!(pool.next().await.is_none());
}
}