use std::collections::HashMap;
use std::sync::Arc;
use crate::{
agent::multi_agent::skills::Skill,
agent::{AgentError, UnifiedAgent},
chain::ChainError,
schemas::messages::Message,
};
/// Wraps a [`UnifiedAgent`] and, before each invocation, injects the context of the registered
/// skills that apply to the most recent human message as a leading system message.
pub struct SkillAgent {
    /// Agent that receives the final (possibly augmented) message list.
    agent: Arc<UnifiedAgent>,
    /// Registered skills, keyed by the value returned by `Skill::name()` at registration time.
    skills: HashMap<String, Arc<dyn Skill>>,
    /// Optional template for the injected system message. Every `{skills}` placeholder is
    /// replaced with the rendered skill contexts; `None` means a built-in preamble is used.
    system_prompt_template: Option<String>,
}
impl SkillAgent {
    /// Creates a `SkillAgent` wrapping `agent`, with no skills registered and no custom
    /// system-prompt template.
    pub fn new(agent: Arc<UnifiedAgent>) -> Self {
        Self {
            agent,
            skills: HashMap::new(),
            system_prompt_template: None,
        }
    }

    /// Registers `skill` under the name returned by its `name()` method.
    ///
    /// A skill previously registered under the same name is replaced.
    pub fn with_skill(mut self, skill: Arc<dyn Skill>) -> Self {
        self.skills.insert(skill.name(), skill);
        self
    }

    /// Registers every skill in `skills`, in order.
    ///
    /// As with [`SkillAgent::with_skill`], a later skill replaces an earlier one that has the
    /// same name.
    pub fn with_skills(mut self, skills: Vec<Arc<dyn Skill>>) -> Self {
        self.skills
            .extend(skills.into_iter().map(|skill| (skill.name(), skill)));
        self
    }

    /// Sets the template used for the injected system message.
    ///
    /// Every occurrence of `{skills}` in the template is replaced with the rendered contexts of
    /// the skills loaded for a request.
    pub fn with_system_prompt_template<S: Into<String>>(mut self, template: S) -> Self {
        self.system_prompt_template = Some(template.into());
        self
    }

    /// Returns the skill registered under `name`, if any.
    pub fn get_skill(&self, name: &str) -> Option<&Arc<dyn Skill>> {
        self.skills.get(name)
    }

    /// Loads the context of the skill registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns an error if no skill is registered under `name`. Otherwise returns whatever the
    /// skill's `load_context` returns.
    pub async fn load_skill(
        &self,
        name: &str,
    ) -> Result<crate::agent::multi_agent::skills::SkillContext, Box<dyn std::error::Error>> {
        let skill = self
            .skills
            .get(name)
            .ok_or_else(|| format!("Skill not found: {}", name))?;
        skill.load_context().await
    }

    /// Invokes the wrapped agent, first prepending a system message that carries the context of
    /// every relevant skill.
    ///
    /// Relevance is decided by each skill's `should_load`, applied to the content of the most
    /// recent human message. Skills are considered in ascending name order, so the generated
    /// prompt is identical across runs for the same input. A skill that fails to load is logged
    /// at `warn` level and skipped. If no skill is loaded, `messages` is forwarded unchanged.
    /// The system message is always inserted at index 0, even if `messages` already has one.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped agent's `invoke_messages`.
    pub async fn invoke_messages(&self, messages: Vec<Message>) -> Result<String, ChainError> {
        let loaded_skills = match messages
            .iter()
            .rev()
            .find(|m| matches!(m.message_type, crate::schemas::MessageType::HumanMessage))
        {
            Some(msg) => self.load_relevant_skills(msg).await,
            None => Vec::new(),
        };

        let mut final_messages = messages;
        if !loaded_skills.is_empty() {
            let system_message = self.render_skills_prompt(&loaded_skills);
            final_messages.insert(0, Message::new_system_message(system_message));
        }
        self.agent.invoke_messages(final_messages).await
    }

    /// Returns the wrapped agent.
    pub fn agent(&self) -> &Arc<UnifiedAgent> {
        &self.agent
    }

    /// Loads, in ascending name order, the context of every registered skill whose
    /// `should_load` accepts the content of `msg`. Load failures are logged and skipped.
    async fn load_relevant_skills(
        &self,
        msg: &Message,
    ) -> Vec<(String, crate::agent::multi_agent::skills::SkillContext)> {
        // HashMap iteration order is arbitrary and differs between instances. Sort by name so
        // the system prompt (and load/log order) is reproducible. Keys are unique, so an
        // unstable sort is sufficient.
        let mut candidates: Vec<(&String, &Arc<dyn Skill>)> = self.skills.iter().collect();
        candidates.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let mut loaded = Vec::new();
        for (name, skill) in candidates {
            if !skill.should_load(&msg.content) {
                continue;
            }
            // `skill` is already in hand, so call it directly instead of re-looking it up
            // through `load_skill` (whose "not found" branch cannot trigger here).
            match skill.load_context().await {
                Ok(context) => loaded.push((name.clone(), context)),
                Err(e) => log::warn!("Failed to load skill {}: {}", name, e),
            }
        }
        loaded
    }

    /// Renders the injected system message for `loaded_skills`.
    ///
    /// Each skill becomes a `## name` section. Sections are separated by blank lines and
    /// substituted into the template, or appended to the default preamble.
    fn render_skills_prompt(
        &self,
        loaded_skills: &[(String, crate::agent::multi_agent::skills::SkillContext)],
    ) -> String {
        let skills_text = loaded_skills
            .iter()
            .map(|(name, ctx)| format!("## {}\n{}", name, ctx.content))
            .collect::<Vec<_>>()
            .join("\n\n");
        match &self.system_prompt_template {
            Some(template) => template.replace("{skills}", &skills_text),
            None => format!(
                "You have access to the following specialized knowledge:\n\n{}",
                skills_text
            ),
        }
    }
}
/// Builder for [`SkillAgent`]. An agent must be supplied before
/// [`SkillAgentBuilder::build`] succeeds.
pub struct SkillAgentBuilder {
    /// Agent to wrap; `build` fails if this is still `None`.
    agent: Option<Arc<UnifiedAgent>>,
    /// Skills to register, in the order they were added.
    skills: Vec<Arc<dyn Skill>>,
    /// Optional system-prompt template forwarded to the built agent.
    system_prompt_template: Option<String>,
}
impl SkillAgentBuilder {
pub fn new() -> Self {
Self {
agent: None,
skills: Vec::new(),
system_prompt_template: None,
}
}
pub fn with_agent(mut self, agent: Arc<UnifiedAgent>) -> Self {
self.agent = Some(agent);
self
}
pub fn with_skill(mut self, skill: Arc<dyn Skill>) -> Self {
self.skills.push(skill);
self
}
pub fn with_skills(mut self, skills: Vec<Arc<dyn Skill>>) -> Self {
self.skills.extend(skills);
self
}
pub fn with_system_prompt_template<S: Into<String>>(mut self, template: S) -> Self {
self.system_prompt_template = Some(template.into());
self
}
pub fn build(self) -> Result<SkillAgent, AgentError> {
let agent = self
.agent
.ok_or_else(|| AgentError::MissingObject("agent".to_string()))?;
let mut skill_agent = SkillAgent::new(agent);
skill_agent = skill_agent.with_skills(self.skills);
if let Some(template) = self.system_prompt_template {
skill_agent = skill_agent.with_system_prompt_template(template);
}
Ok(skill_agent)
}
}
impl Default for SkillAgentBuilder {
fn default() -> Self {
Self::new()
}
}