use std::collections::HashMap;
use std::sync::Arc;
use crate::{
agent::multi_agent::router::{DefaultRouter, Router},
agent::{AgentError, UnifiedAgent},
chain::ChainError,
schemas::messages::Message,
};
/// Dispatches a conversation to one of several named [`UnifiedAgent`]s.
///
/// The agent is chosen by a [`Router`] from the content of the most recent human
/// message. If the router selects no agent, the optional default agent handles the
/// request instead.
pub struct RouterAgent {
    /// Registered agents, keyed by the name the router is expected to return.
    agents: HashMap<String, Arc<UnifiedAgent>>,
    /// Strategy mapping the latest human input to an agent name (or to no agent).
    router: Box<dyn Router>,
    /// Fallback used when the router selects no agent.
    default_agent: Option<Arc<UnifiedAgent>>,
    /// Stored configuration flag. NOTE(review): `invoke_messages` does not consult it
    /// and always dispatches to exactly one agent.
    allow_parallel: bool,
}

impl RouterAgent {
    /// Creates an agent that routes with `router`.
    ///
    /// The new agent has no registered agents, no default agent, and parallel
    /// execution disabled.
    pub fn new(router: Box<dyn Router>) -> Self {
        Self {
            agents: HashMap::new(),
            router,
            default_agent: None,
            allow_parallel: false,
        }
    }

    /// Registers `agent` under `name`, replacing any agent already registered under
    /// that name.
    #[must_use]
    pub fn with_agent(mut self, name: String, agent: Arc<UnifiedAgent>) -> Self {
        self.agents.insert(name, agent);
        self
    }

    /// Registers every `(name, agent)` pair in order.
    ///
    /// A later pair overwrites an earlier one that has the same name.
    #[must_use]
    pub fn with_agents(mut self, agents: Vec<(String, Arc<UnifiedAgent>)>) -> Self {
        // `HashMap::extend` inserts each pair in sequence, so the last-wins semantics
        // of the previous insert loop are preserved.
        self.agents.extend(agents);
        self
    }

    /// Sets the agent used when the router selects no agent.
    #[must_use]
    pub fn with_default_agent(mut self, agent: Arc<UnifiedAgent>) -> Self {
        self.default_agent = Some(agent);
        self
    }

    /// Sets the parallel-execution flag.
    ///
    /// The flag is currently only stored; see [`Self::parallel_execution_enabled`].
    #[must_use]
    pub fn with_parallel_execution(mut self, allow: bool) -> Self {
        self.allow_parallel = allow;
        self
    }

    /// Returns the value set by [`Self::with_parallel_execution`] (`false` by default).
    pub fn parallel_execution_enabled(&self) -> bool {
        self.allow_parallel
    }

    /// Returns the agent registered under `name`, if any.
    pub fn get_agent(&self, name: &str) -> Option<&Arc<UnifiedAgent>> {
        self.agents.get(name)
    }

    /// Routes `messages` to a single agent and returns that agent's response.
    ///
    /// The router receives the `content` of the last message whose type is
    /// `HumanMessage`. The full, unmodified `messages` list is then forwarded to the
    /// selected agent, or to the default agent when the router selects none.
    ///
    /// # Errors
    ///
    /// Returns [`ChainError::AgentError`] in the following cases:
    /// - `messages` contains no human message;
    /// - the router fails;
    /// - the router names an agent that is not registered;
    /// - the router selects no agent and no default agent is configured.
    ///
    /// Errors from the chosen agent's own `invoke_messages` are propagated unchanged.
    pub async fn invoke_messages(&self, messages: Vec<Message>) -> Result<String, ChainError> {
        let last_human_message = messages
            .iter()
            .rev()
            .find(|m| matches!(m.message_type, crate::schemas::MessageType::HumanMessage))
            .ok_or_else(|| ChainError::AgentError("No human message found".to_string()))?;

        let selected_agent_name = self
            .router
            .route(&last_human_message.content)
            .await
            .map_err(|e| ChainError::AgentError(e.to_string()))?;

        // Resolve the target first, then dispatch once. An unknown routed name is an
        // error; it does not fall back to the default agent.
        let agent = match selected_agent_name {
            Some(agent_name) => self.agents.get(&agent_name).ok_or_else(|| {
                ChainError::AgentError(format!("Agent not found: {}", agent_name))
            })?,
            None => self.default_agent.as_ref().ok_or_else(|| {
                ChainError::AgentError(
                    "No suitable agent found and no default agent configured".to_string(),
                )
            })?,
        };

        agent.invoke_messages(messages).await
    }
}
pub struct RouterAgentBuilder {
router: Option<Box<dyn Router>>,
agents: Vec<(String, Arc<UnifiedAgent>)>,
default_agent: Option<Arc<UnifiedAgent>>,
allow_parallel: bool,
}
impl RouterAgentBuilder {
pub fn new() -> Self {
Self {
router: None,
agents: Vec::new(),
default_agent: None,
allow_parallel: false,
}
}
pub fn with_router(mut self, router: Box<dyn Router>) -> Self {
self.router = Some(router);
self
}
pub fn with_llm_router(
self,
llm: Box<dyn crate::language_models::llm::LLM>,
agent_descriptions: Vec<(String, String)>,
) -> Self {
let router = Box::new(DefaultRouter::with_llm(llm, agent_descriptions));
self.with_router(router)
}
pub fn with_keyword_router(
self,
keyword_map: std::collections::HashMap<String, Vec<String>>,
) -> Self {
let router = Box::new(DefaultRouter::with_keywords(keyword_map));
self.with_router(router)
}
pub fn with_agent(mut self, name: String, agent: Arc<UnifiedAgent>) -> Self {
self.agents.push((name, agent));
self
}
pub fn with_agents(mut self, agents: Vec<(String, Arc<UnifiedAgent>)>) -> Self {
self.agents.extend(agents);
self
}
pub fn with_default_agent(mut self, agent: Arc<UnifiedAgent>) -> Self {
self.default_agent = Some(agent);
self
}
pub fn with_parallel_execution(mut self, allow: bool) -> Self {
self.allow_parallel = allow;
self
}
pub fn build(self) -> Result<RouterAgent, AgentError> {
let router = self
.router
.ok_or_else(|| AgentError::MissingObject("router".to_string()))?;
let mut router_agent = RouterAgent::new(router);
router_agent = router_agent.with_agents(self.agents);
if let Some(default) = self.default_agent {
router_agent = router_agent.with_default_agent(default);
}
if self.allow_parallel {
router_agent = router_agent.with_parallel_execution(true);
}
Ok(router_agent)
}
}
impl Default for RouterAgentBuilder {
fn default() -> Self {
Self::new()
}
}