use std::thread::available_parallelism;
use dashmap::DashMap;
use futures::{StreamExt, stream};
use tokio::sync::mpsc;
use tracing::{debug, error, info};
use crate::structs::{
agent::{Agent, AgentError},
conversation::AgentConversation,
};
/// Errors that can abort a batch run before or during execution.
#[derive(Debug, thiserror::Error)]
pub enum BatchExecutionError {
    /// An individual agent failed while running a task; converted from [`AgentError`].
    #[error("Agent error: {0}")]
    AgentError(#[from] AgentError),
    /// `execute_batch` was called on an executor with no agents registered.
    #[error("No agents provided")]
    NoAgents,
    /// `execute_batch` was called with an empty task list.
    #[error("No tasks provided")]
    NoTasks,
    /// Sending a per-task result over the internal mpsc channel failed
    /// (i.e. the receiving half was dropped before the send completed).
    #[error("Channel error: {0}")]
    ChannelError(#[from] mpsc::error::SendError<(String, Result<AgentConversation, AgentError>)>),
}
/// Tuning knobs controlling how a batch of tasks is executed.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Hard cap on simultaneously running tasks; `None` defers to thread calculation.
    pub max_concurrent_tasks: Option<usize>,
    /// When true, the concurrency limit is derived from the detected CPU count.
    pub auto_cpu_optimization: bool,
    /// Explicit worker-thread count; takes precedence over CPU auto-detection.
    pub worker_threads: Option<usize>,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_concurrent_tasks: None,
auto_cpu_optimization: true,
worker_threads: None,
}
}
}
/// Fluent builder producing a [`BatchConfig`].
#[derive(Default)]
pub struct BatchConfigBuilder {
    // Accumulated configuration, starting from `BatchConfig::default()`.
    config: BatchConfig,
}
impl BatchConfigBuilder {
pub fn max_concurrent_tasks(mut self, max: usize) -> Self {
self.config.max_concurrent_tasks = Some(max);
self
}
pub fn auto_cpu_optimization(mut self, enable: bool) -> Self {
self.config.auto_cpu_optimization = enable;
self
}
pub fn worker_threads(mut self, threads: usize) -> Self {
self.config.worker_threads = Some(threads);
self
}
pub fn build(self) -> BatchConfig {
self.config
}
}
/// Fans a list of tasks out across a pool of agents with bounded concurrency.
pub struct AgentBatchExecutor {
    // Agents that each task is run through, in registration order.
    agents: Vec<Box<dyn Agent>>,
    // Concurrency/threading settings applied to each batch run.
    config: BatchConfig,
}
impl AgentBatchExecutor {
pub fn new(agents: Vec<Box<dyn Agent>>, config: BatchConfig) -> Self {
Self { agents, config }
}
pub fn builder() -> AgentBatchExecutorBuilder {
AgentBatchExecutorBuilder::default()
}
fn calculate_optimal_threads(&self) -> usize {
if let Some(threads) = self.config.worker_threads {
return threads;
}
if !self.config.auto_cpu_optimization {
return 4; }
match available_parallelism() {
Ok(num_cpus) => {
let cpus = num_cpus.get();
debug!("Detected {} CPU cores", cpus);
cpus
},
Err(e) => {
error!("Failed to determine CPU count: {}", e);
4 },
}
}
pub async fn execute_batch(
&self,
tasks: Vec<String>,
) -> Result<DashMap<String, AgentConversation>, BatchExecutionError> {
if self.agents.is_empty() {
return Err(BatchExecutionError::NoAgents);
}
if tasks.is_empty() {
return Err(BatchExecutionError::NoTasks);
}
let results = DashMap::with_capacity(tasks.len());
let (tx, mut rx) = mpsc::channel(tasks.len());
let max_concurrent = self
.config
.max_concurrent_tasks
.unwrap_or_else(|| self.calculate_optimal_threads());
info!(
"Starting batch execution with {} tasks across {} agents (max concurrent: {})",
tasks.len(),
self.agents.len(),
max_concurrent
);
stream::iter(tasks)
.for_each_concurrent(max_concurrent, |task| {
let tx = tx.clone();
let agents = &self.agents;
async move {
for agent in agents {
match agent.run(task.clone()).await {
Ok(response) => {
let mut conversation = AgentConversation::new(agent.name());
conversation.add(
crate::structs::conversation::Role::Assistant(agent.name()),
response,
);
tx.send((task.clone(), Ok(conversation))).await.unwrap();
},
Err(e) => {
error!(
"Agent {} failed to process task '{}': {}",
agent.name(),
task,
e
);
tx.send((task.clone(), Err(e))).await.unwrap();
},
}
}
}
})
.await;
drop(tx);
while let Some((task, result)) = rx.recv().await {
match result {
Ok(conversation) => {
results.insert(task, conversation);
},
Err(e) => {
error!("Task failed: {}", e);
},
}
}
info!("Batch execution completed with {} results", results.len());
Ok(results)
}
}
/// Builder for [`AgentBatchExecutor`]: register agents, optionally set a
/// config, then call `build()`.
#[derive(Default)]
pub struct AgentBatchExecutorBuilder {
    // Agents registered so far.
    agents: Vec<Box<dyn Agent>>,
    // Execution settings; starts from `BatchConfig::default()`.
    config: BatchConfig,
}
impl AgentBatchExecutorBuilder {
    /// Appends an agent to the pool the executor will fan tasks across.
    pub fn add_agent(self, agent: Box<dyn Agent>) -> Self {
        let mut builder = self;
        builder.agents.push(agent);
        builder
    }

    /// Replaces the execution configuration wholesale.
    pub fn config(self, config: BatchConfig) -> Self {
        Self { config, ..self }
    }

    /// Finalizes the builder into a ready-to-run executor.
    pub fn build(self) -> AgentBatchExecutor {
        let Self { agents, config } = self;
        AgentBatchExecutor::new(agents, config)
    }
}