use async_trait::async_trait;
use crate::models::graph::Agent;
use crate::models::tools::ToolRegistryTrait;
use std::sync::Arc;
use tokio::sync::Mutex;
/// How [`ParallelAgent`] merges the outputs of its sub-agents into a single string.
#[derive(Clone)]
pub enum CombineStrategy {
    /// Join every output with `"\n"`, preserving agent order.
    Concatenate,
    /// Return the first non-empty output, or an empty string when all are empty.
    FirstValid,
    /// Serialize all outputs as a JSON array of strings.
    JsonArray,
    /// Hand every output to a caller-supplied combiner function.
    Custom(Arc<dyn Fn(Vec<String>) -> String + Send + Sync>),
}

// `Debug` cannot be derived because the `Custom` closure is an opaque trait
// object; print a placeholder for it so this enum can still appear in logs and
// in `{:?}` output of types that embed it.
impl std::fmt::Debug for CombineStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CombineStrategy::Concatenate => f.write_str("Concatenate"),
            CombineStrategy::FirstValid => f.write_str("FirstValid"),
            CombineStrategy::JsonArray => f.write_str("JsonArray"),
            CombineStrategy::Custom(_) => f.write_str("Custom(<fn>)"),
        }
    }
}
/// An [`Agent`] that sends the same input to several sub-agents concurrently
/// and merges their text outputs according to a [`CombineStrategy`].
pub struct ParallelAgent {
    /// Human-readable name returned from `get_name`.
    name: String,
    /// Sub-agents, each behind its own async (`tokio`) mutex so a future can
    /// hold mutable access to it across `.await` points.
    agents: Vec<Arc<Mutex<Box<dyn Agent>>>>,
    /// Rule used to merge the collected outputs into one string.
    strategy: CombineStrategy,
    /// Optional limit on the whole fan-out; `None` waits for every sub-agent
    /// with no deadline.
    timeout: Option<tokio::time::Duration>,
}
impl ParallelAgent {
pub fn new() -> Self {
Self {
agents: Vec::new(),
strategy: CombineStrategy::Concatenate,
name: "Parallel".to_string(),
timeout: None,
}
}
pub fn add_agent(mut self, agent: Box<dyn Agent>) -> Self {
self.agents.push(Arc::new(Mutex::new(agent)));
self
}
pub fn with_strategy(mut self, strategy: CombineStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn with_timeout(mut self, timeout: tokio::time::Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}
fn combine_results(&self, results: Vec<String>) -> String {
match &self.strategy {
CombineStrategy::Concatenate => results.join("\n"),
CombineStrategy::FirstValid => {
results.into_iter()
.find(|s| !s.is_empty())
.unwrap_or_default()
}
CombineStrategy::JsonArray => {
serde_json::json!(results).to_string()
}
CombineStrategy::Custom(combiner) => combiner(results),
}
}
}
impl Default for ParallelAgent {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for ParallelAgent {
    /// Runs every configured sub-agent concurrently on the same `input` and
    /// merges their text outputs with the configured [`CombineStrategy`].
    ///
    /// The sub-agent futures are driven together on the current task via
    /// `futures::future::join_all` (nothing is spawned), and results keep the
    /// order in which agents were added. If a timeout is configured and it
    /// elapses before every sub-agent has finished, the in-flight work is
    /// abandoned and a fixed timeout message is returned instead of partial
    /// results.
    ///
    /// The `Option<i32>` each sub-agent returns is discarded; this agent always
    /// reports `None` in that position.
    async fn run(
        &mut self,
        input: &str,
        tool_registry: &(dyn ToolRegistryTrait + Send + Sync),
    ) -> (String, Option<i32>) {
        if self.agents.is_empty() {
            return ("No agents configured for parallel execution".to_string(), None);
        }

        // Every future is awaited before this function returns, so it can borrow
        // `input`, `tool_registry` and its agent handle directly; there is no need
        // to allocate a `String` copy or clone an `Arc` per agent.
        let pending: Vec<_> = self
            .agents
            .iter()
            .map(move |agent| async move {
                let mut guard = agent.lock().await;
                guard.run(input, tool_registry).await.0
            })
            .collect();

        let all = futures::future::join_all(pending);
        let results = match self.timeout {
            Some(limit) => match tokio::time::timeout(limit, all).await {
                Ok(results) => results,
                Err(_) => return ("Parallel execution timed out".to_string(), None),
            },
            None => all.await,
        };

        (self.combine_results(results), None)
    }

    fn get_name(&self) -> &str {
        &self.name
    }
}
use futures;