use dashmap::DashMap;
use serde::Deserialize;
use crate::agent::SwarmsAgent;
use crate::llm::provider::openai::OpenAI;
use crate::prompts::multi_agent_collab_prompt::MULTI_AGENT_COLLAB_PROMPT;
use crate::structs::agent::Agent;
use crate::structs::concurrent_workflow::ConcurrentWorkflow;
use crate::structs::concurrent_workflow::ConcurrentWorkflowError;
use crate::structs::conversation::AgentConversation;
use crate::structs::sequential_workflow::SequentialWorkflow;
use crate::structs::sequential_workflow::SequentialWorkflowError;
/// Selects which workflow implementation a `SwarmRouter` is built on.
///
/// This is a plain fieldless tag, so it is `Copy` and can be compared, hashed and
/// stored in sets/maps directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
pub enum SwarmType {
    /// Build the router on top of a `SequentialWorkflow`.
    SequentialWorkflow,
    /// Build the router on top of a `ConcurrentWorkflow`.
    ConcurrentWorkflow,
}
/// Configuration consumed by `SwarmRouter::new_with_config` and `swarm_router`.
pub struct SwarmRouterConfig {
/// Name handed to the underlying workflow builder.
pub name: String,
/// Description handed to the underlying workflow builder.
pub description: String,
/// Which workflow implementation the router is built on.
pub swarm_type: SwarmType,
/// Agents that make up the swarm; validation rejects an empty list.
pub agents: Vec<SwarmsAgent<OpenAI>>,
/// Optional rules appended to every agent's system prompt under a
/// `### SWARM RULES ###` heading.
pub rules: Option<String>,
/// When `true`, the multi-agent collaboration prompt is appended to every
/// agent's system prompt.
pub multi_agent_collab_prompt: bool,
}
impl Default for SwarmRouterConfig {
    /// A sequential swarm named `swarm-router` with no agents and no rules, and with the
    /// multi-agent collaboration prompt enabled.
    ///
    /// Note that the default has no agents, so it fails validation until agents are added.
    fn default() -> Self {
        Self {
            name: "swarm-router".to_owned(),
            description: "Routes your task to the desired swarm".to_owned(),
            swarm_type: SwarmType::SequentialWorkflow,
            agents: vec![],
            rules: None,
            multi_agent_collab_prompt: true,
        }
    }
}
impl SwarmRouterConfig {
    /// Runs the reliability checks a configuration must pass before a router is built.
    ///
    /// # Errors
    ///
    /// Returns [`SwarmRouterError::ValidationError`] when `agents` is empty.
    fn validate(&self) -> Result<(), SwarmRouterError> {
        tracing::info!("Initializing reliability checks");
        self.validate_agents()?;
        tracing::info!("Reliability checks completed your swarm is ready");
        Ok(())
    }

    /// Appends the configured swarm rules to every agent's system prompt.
    ///
    /// The rules are added under a `### SWARM RULES ###` heading. Nothing happens when
    /// `rules` is `None` or blank; a whitespace-only string would otherwise leave an
    /// empty rules heading in every agent's prompt.
    fn handle_rules(&mut self) {
        let rules_section = match self.rules.as_deref() {
            Some(rules) if !rules.trim().is_empty() => format!("### SWARM RULES ###\n{rules}"),
            _ => return,
        };
        tracing::info!("Injecting rules to every agent!");
        self.append_to_system_prompts(rules_section);
        tracing::info!("Finished injecting rules");
    }

    /// Appends the shared multi-agent collaboration prompt to every agent's system prompt.
    fn update_system_prompt_for_agent_in_swarm(&mut self) {
        tracing::info!("Injecting multi-agent prompt to every agent!");
        self.append_to_system_prompts(MULTI_AGENT_COLLAB_PROMPT);
        tracing::info!("Finished injecting multi-agent prompt");
    }

    /// Rebuilds every agent with `suffix` appended to its system prompt after a newline.
    ///
    /// An agent without a system prompt is treated as having an empty one.
    fn append_to_system_prompts(&mut self, suffix: impl std::fmt::Display) {
        // The agent's `system_prompt` setter consumes the agent, so temporarily move the
        // whole list out of `self` and rebuild it.
        let agents = std::mem::take(&mut self.agents);
        self.agents = agents
            .into_iter()
            .map(|agent| {
                let current = agent.get_system_prompt().unwrap_or("");
                let updated = format!("{current}\n{suffix}");
                agent.system_prompt(updated)
            })
            .collect();
    }

    /// Ensures the swarm has at least one agent.
    ///
    /// # Errors
    ///
    /// Returns [`SwarmRouterError::ValidationError`] when `agents` is empty.
    fn validate_agents(&self) -> Result<(), SwarmRouterError> {
        if self.agents.is_empty() {
            return Err(SwarmRouterError::ValidationError(String::from(
                "No agents provided for the swarm.",
            )));
        }
        Ok(())
    }
}
/// A configured swarm that dispatches tasks to one of the supported workflow types.
///
/// Construct one with `SwarmRouter::new_with_config`.
pub enum SwarmRouter {
/// Router backed by a `SequentialWorkflow`.
SequentialWorkflow(SequentialWorkflow),
/// Router backed by a `ConcurrentWorkflow`.
ConcurrentWorkflow(ConcurrentWorkflow),
}
impl SwarmRouter {
pub fn new_with_config(config: SwarmRouterConfig) -> Result<SwarmRouter, SwarmRouterError> {
config.validate()?;
tracing::info!(
"SwarmRouter initialized with swarm type: {:?}",
config.swarm_type
);
let mut config = config;
config.handle_rules();
if config.multi_agent_collab_prompt {
config.update_system_prompt_for_agent_in_swarm();
}
let swarm_router = SwarmRouter::create_swarm_router(config);
Ok(swarm_router)
}
pub async fn run(&self, task: &str) -> Result<AgentConversation, SwarmRouterError> {
let result = self.inner_run(task).await;
if let Err(err) = &result {
tracing::error!("Error executing task on swarm: {err}");
}
result
}
pub async fn batch_run(
&self,
tasks: Vec<String>,
) -> Result<DashMap<String, AgentConversation>, SwarmRouterError> {
let result = self.inner_batch_run(tasks).await;
if let Err(err) = &result {
tracing::error!("Error executing task on swarm: {err}");
}
result
}
async fn inner_run(&self, task: &str) -> Result<AgentConversation, SwarmRouterError> {
tracing::info!("Running task on {:?} swarm with task: {task}", self.kind());
let result = match self {
SwarmRouter::SequentialWorkflow(wf) => wf.run(task).await?,
SwarmRouter::ConcurrentWorkflow(wf) => wf.run(task).await?,
};
tracing::info!("Swarm completed successfully");
Ok(result)
}
async fn inner_batch_run(
&self,
tasks: Vec<String>,
) -> Result<DashMap<String, AgentConversation>, SwarmRouterError> {
tracing::info!("Running batch tasks on {:?} swarm", self.kind());
let result = match self {
SwarmRouter::SequentialWorkflow(wf) => {
let results = DashMap::with_capacity(tasks.len());
for task in tasks {
let result = wf.run(&task).await?;
results.insert(task, result);
}
results
},
SwarmRouter::ConcurrentWorkflow(wf) => wf.run_batch(tasks).await?,
};
tracing::info!("Swarm completed successfully");
Ok(result)
}
fn kind(&self) -> SwarmType {
match self {
SwarmRouter::SequentialWorkflow(_) => SwarmType::SequentialWorkflow,
SwarmRouter::ConcurrentWorkflow(_) => SwarmType::ConcurrentWorkflow,
}
}
fn create_swarm_router(config: SwarmRouterConfig) -> SwarmRouter {
let agents = config
.agents
.into_iter()
.map(boxed_agent)
.collect::<Vec<_>>();
match config.swarm_type {
SwarmType::SequentialWorkflow => {
let workflow = SequentialWorkflow::builder()
.name(config.name)
.description(config.description)
.agents(agents)
.build();
SwarmRouter::SequentialWorkflow(workflow)
},
SwarmType::ConcurrentWorkflow => {
let workflow = ConcurrentWorkflow::builder()
.name(config.name)
.description(config.description)
.agents(agents)
.build();
SwarmRouter::ConcurrentWorkflow(workflow)
},
}
}
}
/// One-shot helper that builds a [`SwarmRouter`] from `config` and runs `task` on it.
///
/// # Errors
///
/// Returns a [`SwarmRouterError`] if the configuration fails validation or the task fails.
pub async fn swarm_router(
    task: &str,
    config: SwarmRouterConfig,
) -> Result<AgentConversation, SwarmRouterError> {
    tracing::info!(
        "Creating SwarmRouter with name: {}, swarm_type: {:?}",
        config.name,
        config.swarm_type
    );
    let router = SwarmRouter::new_with_config(config)?;
    tracing::info!("Executing task with SwarmRouter: {}", task);
    router.run(task).await.map(|conversation| {
        tracing::info!("Task execution completed successfully");
        conversation
    })
}
/// Boxes an agent as a `Box<dyn Agent>` trait object, the form the workflow builders'
/// `agents` setter receives in this module.
fn boxed_agent(agent: SwarmsAgent<OpenAI>) -> Box<dyn Agent> {
    Box::new(agent) as Box<dyn Agent>
}
/// Errors produced while building or running a `SwarmRouter`.
#[derive(Debug, thiserror::Error)]
pub enum SwarmRouterError {
/// The configuration was rejected before a router could be built
/// (e.g. no agents were provided).
#[error("SwarmRouter validation error: {0}")]
ValidationError(String),
/// An error raised by the underlying sequential workflow; its message is forwarded as-is.
#[error(transparent)]
SequentialWorkflowError(#[from] SequentialWorkflowError),
/// An error raised by the underlying concurrent workflow; its message is forwarded as-is.
#[error(transparent)]
ConcurrentWorkflowError(#[from] ConcurrentWorkflowError),
}