use std::collections::HashMap;
use std::sync::Arc;
use crate::{
agent::multi_agent::handoffs::HandoffTool,
agent::{AgentError, UnifiedAgent},
chain::ChainError,
schemas::messages::Message,
tools::Tool,
};
pub struct HandoffAgent {
agents: HashMap<String, Arc<UnifiedAgent>>,
default_agent: Option<Arc<UnifiedAgent>>,
handoff_tool: Arc<HandoffTool>,
}
impl HandoffAgent {
    /// Creates a coordinator with no registered agents, no default agent,
    /// and a freshly constructed [`HandoffTool`].
    pub fn new() -> Self {
        Self {
            agents: HashMap::new(),
            default_agent: None,
            handoff_tool: Arc::new(HandoffTool::new()),
        }
    }

    /// Registers `agent` under `name` and returns the updated coordinator.
    ///
    /// Registering a name that is already present replaces the earlier agent.
    #[must_use]
    pub fn with_agent(mut self, name: String, agent: Arc<UnifiedAgent>) -> Self {
        self.agents.insert(name, agent);
        self
    }

    /// Sets the agent that [`Self::invoke_messages`] dispatches to first and
    /// returns the updated coordinator.
    #[must_use]
    pub fn with_default_agent(mut self, agent: Arc<UnifiedAgent>) -> Self {
        self.default_agent = Some(agent);
        self
    }

    /// Returns a shared handle to this coordinator's handoff tool as a
    /// `dyn Tool` trait object.
    pub fn handoff_tool(&self) -> Arc<dyn Tool> {
        self.handoff_tool.clone()
    }

    /// Looks up a registered agent by `name`.
    pub fn get_agent(&self, name: &str) -> Option<&Arc<UnifiedAgent>> {
        self.agents.get(name)
    }

    /// Returns the default agent, if one has been set.
    pub fn default_agent(&self) -> Option<&Arc<UnifiedAgent>> {
        self.default_agent.as_ref()
    }

    /// Forwards `messages` to a single agent and returns its response.
    ///
    /// The default agent is used if one is set. Otherwise the registered
    /// agent whose name sorts first is used. No message-based routing is
    /// performed here.
    ///
    /// # Errors
    ///
    /// Returns `ChainError::AgentError` when no agent is available at all.
    /// Otherwise it returns whatever result the selected agent produces.
    pub async fn invoke_messages(&self, messages: Vec<Message>) -> Result<String, ChainError> {
        let agent = self
            .default_agent
            .as_ref()
            // `HashMap` iteration order is unspecified and varies between
            // runs. Picking the smallest name keeps the fallback deterministic.
            .or_else(|| {
                self.agents
                    .iter()
                    .min_by_key(|&(name, _)| name)
                    .map(|(_, agent)| agent)
            })
            .ok_or_else(|| ChainError::AgentError("No agent available for handoff".to_string()))?;
        agent.invoke_messages(messages).await
    }
}
impl Default for HandoffAgent {
fn default() -> Self {
Self::new()
}
}
pub struct HandoffAgentBuilder {
base_agent: Option<Arc<UnifiedAgent>>,
handoff_agents: HashMap<String, Arc<UnifiedAgent>>,
}
impl HandoffAgentBuilder {
    /// Starts an empty builder with no base agent and no handoff agents.
    pub fn new() -> Self {
        let handoff_agents = HashMap::new();
        Self {
            base_agent: None,
            handoff_agents,
        }
    }

    /// Sets the agent that the built coordinator will use as its default.
    pub fn with_base_agent(self, agent: Arc<UnifiedAgent>) -> Self {
        Self {
            base_agent: Some(agent),
            ..self
        }
    }

    /// Adds a named handoff agent. Reusing a name replaces the earlier agent.
    pub fn with_handoff_agent(mut self, name: String, agent: Arc<UnifiedAgent>) -> Self {
        self.handoff_agents.insert(name, agent);
        self
    }

    /// Consumes the builder and produces the configured [`HandoffAgent`].
    ///
    /// The current implementation never returns `Err`. The `Result` is part
    /// of the public signature.
    pub fn build(self) -> Result<HandoffAgent, AgentError> {
        let Self {
            base_agent,
            handoff_agents,
        } = self;

        let assembled = handoff_agents
            .into_iter()
            .fold(HandoffAgent::new(), |acc, (name, agent)| acc.with_agent(name, agent));

        let assembled = match base_agent {
            Some(base) => assembled.with_default_agent(base),
            None => assembled,
        };

        Ok(assembled)
    }
}
impl Default for HandoffAgentBuilder {
fn default() -> Self {
Self::new()
}
}