use anyhow::{Context, Result};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::Mutex;
use crate::config::{get_agent_config, load_config, load_system_prompt, load_tools, LLMConfig};
use crate::llm::{create_llm_provider, ChatOptions, LLMProvider, Message};
use crate::policy::{ActivePolicy, PolicyEnforcer};
/// Process-wide cache of parsed policies, keyed by the file path they were
/// loaded from. Entries are appended when `get_agent` loads a policy; nothing
/// in this file evicts them, so a policy is read at most once per key.
static POLICY_CACHE: Mutex<Vec<(String, ActivePolicy)>> = Mutex::new(vec![]);
/// An agent assembled from configuration: an LLM backend, a system prompt,
/// the declared tool names, an in-memory conversation log and an optional
/// policy enforcer.
pub struct SekuireAgent {
    /// Name the agent was constructed with.
    pub name: String,
    /// Backend that serves [`SekuireAgent::chat`] requests.
    llm: Box<dyn LLMProvider>,
    /// Sent as the leading `"system"` message of every request.
    system_prompt: String,
    /// Declared tool names, before any policy filtering.
    tools: Vec<String>,
    /// User/assistant messages appended by each successful `chat` call.
    /// This log itself is never truncated; only the request window is.
    conversation_history: Vec<Message>,
    /// Upper bound on how many of the most recent history messages are
    /// replayed with each request.
    max_history_messages: usize,
    /// Optional gatekeeper for model choice, request rate and tool visibility.
    policy_enforcer: Option<PolicyEnforcer>,
}

impl SekuireAgent {
    /// History window used when the constructor receives `None`.
    const DEFAULT_MAX_HISTORY_MESSAGES: usize = 10;

    /// Creates an agent with an empty conversation history.
    ///
    /// `max_history_messages` falls back to 10 when `None`.
    pub fn new(
        name: String,
        llm: Box<dyn LLMProvider>,
        system_prompt: String,
        tools: Vec<String>,
        max_history_messages: Option<usize>,
        policy_enforcer: Option<PolicyEnforcer>,
    ) -> Self {
        Self {
            max_history_messages: max_history_messages
                .unwrap_or(Self::DEFAULT_MAX_HISTORY_MESSAGES),
            conversation_history: Vec::new(),
            name,
            llm,
            system_prompt,
            tools,
            policy_enforcer,
        }
    }

    /// Sends `user_message` to the LLM and returns the assistant's reply text.
    ///
    /// The request is the system prompt, followed by at most
    /// `max_history_messages` of the most recent history entries, followed by
    /// the new user message. Only after the LLM call succeeds are the user
    /// message and the reply appended to the history.
    ///
    /// # Errors
    ///
    /// Returns an error if the policy enforcer's model check or rate-limit
    /// check fails, or if the LLM call fails. In every error case the history
    /// is left unchanged.
    pub async fn chat(
        &mut self,
        user_message: &str,
        options: Option<ChatOptions>,
    ) -> Result<String> {
        if let Some(enforcer) = self.policy_enforcer.as_ref() {
            enforcer.enforce_model(self.get_model_name())?;
            enforcer.enforce_rate_limit("request", 1)?;
        }

        // Build the request in its own scope so no borrow of the history is
        // alive across the `.await` below.
        let request = {
            let skip = self
                .conversation_history
                .len()
                .saturating_sub(self.max_history_messages);
            let window = &self.conversation_history[skip..];
            let mut outgoing = Vec::with_capacity(window.len() + 2);
            outgoing.push(Message {
                role: "system".to_string(),
                content: self.system_prompt.clone(),
            });
            outgoing.extend_from_slice(window);
            outgoing.push(Message {
                role: "user".to_string(),
                content: user_message.to_string(),
            });
            outgoing
        };

        let reply = self.llm.chat(&request, options).await?;
        self.conversation_history.extend([
            Message {
                role: "user".to_string(),
                content: user_message.to_string(),
            },
            Message {
                role: "assistant".to_string(),
                content: reply.content.clone(),
            },
        ]);
        Ok(reply.content)
    }

    /// Drops every recorded message.
    pub fn clear_history(&mut self) {
        self.conversation_history.clear();
    }

    /// Full recorded conversation, oldest first.
    pub fn get_history(&self) -> &[Message] {
        self.conversation_history.as_slice()
    }

    /// Provider name reported by the LLM backend.
    pub fn get_llm_provider(&self) -> &str {
        self.llm.get_provider_name()
    }

    /// Model name reported by the LLM backend.
    pub fn get_model_name(&self) -> &str {
        self.llm.get_model_name()
    }

    /// Tool names this agent may use.
    ///
    /// Without a policy enforcer every declared tool is returned. With one,
    /// only tools for which `enforce_tool` succeeds are kept.
    pub fn get_tools(&self) -> Vec<String> {
        match &self.policy_enforcer {
            None => self.tools.clone(),
            Some(enforcer) => self
                .tools
                .iter()
                .filter(|tool| enforcer.enforce_tool(tool).is_ok())
                .cloned()
                .collect(),
        }
    }

    /// The attached policy enforcer, if any.
    pub fn get_policy_enforcer(&self) -> Option<&PolicyEnforcer> {
        self.policy_enforcer.as_ref()
    }
}
pub async fn get_agent(
agent_name: Option<&str>,
config_path: Option<&str>,
) -> Result<SekuireAgent> {
let config_file = config_path.unwrap_or("./sekuire.yml");
let config = load_config(config_file)?;
let agent_config = get_agent_config(&config, agent_name)?;
let llm_provider = create_llm_provider_from_config(&agent_config.llm).await?;
let system_prompt = load_system_prompt(&agent_config.system_prompt, Some(&".".to_string()))?;
let tools_schema = load_tools(&agent_config.tools, Some(&".".to_string()))?;
let tool_names: Vec<String> = tools_schema.tools.iter().map(|t| t.name.clone()).collect();
let max_history_messages = agent_config.memory.as_ref().and_then(|m| m.max_messages);
let policy_path = "policy.json";
let mut policy_enforcer = None;
let mut cached_policy: Option<ActivePolicy> = None;
if let Ok(cache) = POLICY_CACHE.lock() {
if let Some((_, policy)) = cache.iter().find(|(k, _)| k == policy_path) {
cached_policy = Some(policy.clone());
}
}
if let Some(policy) = cached_policy {
policy_enforcer = Some(PolicyEnforcer::new(policy));
} else {
if Path::new(policy_path).exists() {
if let Ok(content) = fs::read_to_string(policy_path) {
if let Ok(policy) = serde_json::from_str::<ActivePolicy>(&content) {
if let Ok(mut cache) = POLICY_CACHE.lock() {
cache.push((policy_path.to_string(), policy.clone()));
}
policy_enforcer = Some(PolicyEnforcer::new(policy));
}
}
}
}
Ok(SekuireAgent::new(
agent_config.name,
llm_provider,
system_prompt,
tool_names,
max_history_messages,
policy_enforcer,
))
}
/// Instantiates the LLM backend described by `llm_config`.
///
/// The API key is read from the environment variable named by
/// `llm_config.api_key_env`. The remaining settings (provider, model, base
/// URL, temperature, max tokens) are forwarded to `create_llm_provider`.
///
/// # Errors
///
/// Fails if the API-key variable cannot be read, or if `create_llm_provider`
/// fails.
async fn create_llm_provider_from_config(llm_config: &LLMConfig) -> Result<Box<dyn LLMProvider>> {
    let key_var = &llm_config.api_key_env;
    let api_key = std::env::var(key_var).with_context(|| {
        format!("API key not found: {key_var}. Please set this environment variable.")
    })?;

    let provider = create_llm_provider(
        &llm_config.provider,
        api_key,
        llm_config.model.clone(),
        llm_config.base_url.clone(),
        llm_config.temperature,
        llm_config.max_tokens,
    );
    provider.await
}
/// Builds every agent declared in the configuration file, keyed by name.
///
/// * When the config has an `agents` map, one agent is built per key.
/// * Otherwise, if it has a single `agent`, that agent is built and keyed by
///   the project name.
/// * With neither, the result is an empty map.
///
/// `config_path` defaults to `./sekuire.yml`.
///
/// # Errors
///
/// Fails on the first config-loading or agent-construction error.
pub async fn get_agents(config_path: Option<&str>) -> Result<HashMap<String, SekuireAgent>> {
    let config_file = config_path.unwrap_or("./sekuire.yml");
    let config = load_config(config_file)?;
    let mut agents_by_name = HashMap::new();

    match (&config.agents, &config.agent) {
        (Some(agents), _) => {
            for name in agents.keys() {
                let agent = get_agent(Some(name), Some(config_file)).await?;
                agents_by_name.insert(name.clone(), agent);
            }
        }
        (None, Some(_)) => {
            let agent = get_agent(None, Some(config_file)).await?;
            agents_by_name.insert(config.project.name.clone(), agent);
        }
        (None, None) => {}
    }

    Ok(agents_by_name)
}
/// Loads the system prompt configured for `agent_name`.
///
/// `config_path` defaults to `./sekuire.yml`. The prompt is resolved with base
/// path `"."` (the current working directory).
///
/// # Errors
///
/// Fails if the config, the agent entry or the prompt cannot be loaded.
pub async fn get_system_prompt(
    agent_name: Option<&str>,
    config_path: Option<&str>,
) -> Result<String> {
    let config = load_config(config_path.unwrap_or("./sekuire.yml"))?;
    let agent_config = get_agent_config(&config, agent_name)?;
    load_system_prompt(&agent_config.system_prompt, Some(&String::from(".")))
}
/// Lists the tool names declared for `agent_name`.
///
/// `config_path` defaults to `./sekuire.yml`. Tool definitions are resolved
/// with base path `"."`. No policy filtering is applied here.
///
/// # Errors
///
/// Fails if the config, the agent entry or the tool definitions cannot be
/// loaded.
pub async fn get_tools(agent_name: Option<&str>, config_path: Option<&str>) -> Result<Vec<String>> {
    let config = load_config(config_path.unwrap_or("./sekuire.yml"))?;
    let agent_config = get_agent_config(&config, agent_name)?;
    let schema = load_tools(&agent_config.tools, Some(&String::from(".")))?;
    let names: Vec<String> = schema.tools.iter().map(|tool| tool.name.clone()).collect();
    Ok(names)
}