use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Top-level configuration, deserialized from YAML by `load_config`.
///
/// Agent settings may be supplied in one of three shapes, which `get_agent_config` checks in
/// this order: a map of named `agents`, a single `agent`, or a bare top-level `llm` section.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SekuireConfig {
/// Project metadata (required).
pub project: ProjectConfig,
/// Named agents keyed by agent name; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub agents: Option<HashMap<String, AgentConfig>>,
/// A single agent; `get_agent_config` consults it only when `agents` is absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub agent: Option<AgentConfig>,
/// Top-level LLM settings; `get_agent_config` synthesizes an agent from these when neither
/// `agents` nor `agent` is present.
#[serde(skip_serializing_if = "Option::is_none")]
pub llm: Option<LLMConfig>,
/// Optional logger settings; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub logger: Option<LoggerConfig>,
}
/// Project metadata from the `project` section of the config.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProjectConfig {
/// Project name; also used as the agent name when `get_agent_config` synthesizes an agent
/// from a top-level `llm` section.
pub name: String,
/// Project version (free-form string; not validated here).
pub version: String,
/// Optional description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional author; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub author: Option<String>,
/// Optional license identifier (free-form); omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub license: Option<String>,
}
/// Configuration for a single agent.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentConfig {
/// Agent name.
pub name: String,
/// System prompt location. Presumably a path to be read with `load_system_prompt`; the
/// default synthesized by `get_agent_config` is `"./system_prompt.md"` — TODO confirm.
pub system_prompt: String,
/// Tools file location. Presumably a path to be read with `load_tools`; the default
/// synthesized by `get_agent_config` is `"./tools.json"` — TODO confirm.
pub tools: String,
/// LLM settings; `load_config` expands placeholders in `api_key_env` and `base_url`.
pub llm: LLMConfig,
/// Optional memory settings; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub memory: Option<MemoryConfig>,
/// Optional compliance settings; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub compliance: Option<ComplianceConfig>,
}
/// Settings for calling an LLM provider.
///
/// `load_config` expands `${VAR}` / `${VAR:-default}` placeholders in `api_key_env` and
/// `base_url` from the process environment; the other fields are kept as parsed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LLMConfig {
/// Provider identifier (free-form; accepted values are defined by the consumer, not here).
pub provider: String,
/// Model identifier (free-form).
pub model: String,
/// Placeholder-expanded at load time. NOTE(review): the name suggests this holds the *name*
/// of the environment variable containing the API key rather than the key itself — confirm.
pub api_key_env: String,
/// Optional temperature; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Optional token limit; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional streaming flag; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub streaming: Option<bool>,
/// Optional base URL, placeholder-expanded at load time (presumably overrides the
/// provider's default endpoint — confirm); omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub base_url: Option<String>,
}
/// Memory settings for an agent.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MemoryConfig {
/// Memory kind (free-form string). Stored under the key `type` in the config, because
/// `type` is a Rust keyword and cannot be used as the field name.
#[serde(rename = "type")]
pub memory_type: String,
/// Optional message cap (presumably the number of retained messages — TODO confirm);
/// omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_messages: Option<usize>,
}
/// Compliance settings for an agent. Every field is optional and omitted from serialized
/// output when `None`; how they are enforced is up to the consumer (not shown here).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplianceConfig {
/// Optional compliance framework name (free-form).
#[serde(skip_serializing_if = "Option::is_none")]
pub framework: Option<String>,
/// Optional audit-logging flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub audit_logging: Option<bool>,
/// Optional sensitive-data-detection flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub sensitive_data_detection: Option<bool>,
/// Optional list of names requiring approval (presumably tool or action names — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub require_approval: Option<Vec<String>>,
}
/// Logger settings. Every field is optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoggerConfig {
/// Optional base URL for the logging API. Note: unlike `LLMConfig::base_url`, this is not
/// placeholder-expanded by `load_config`.
#[serde(skip_serializing_if = "Option::is_none")]
pub api_base_url: Option<String>,
/// Optional environment label (free-form).
#[serde(skip_serializing_if = "Option::is_none")]
pub environment: Option<String>,
/// Optional enable flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub enabled: Option<bool>,
}
/// Contents of a tools file, parsed from JSON by `load_tools`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolsSchema {
/// Schema version (free-form string; not validated here).
pub version: String,
/// Tool definitions, in file order.
pub tools: Vec<ToolDefinition>,
}
/// A single tool entry in a tools file.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Human-readable description.
pub description: String,
/// Optional category; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub category: Option<String>,
/// Arbitrary JSON value, kept unparsed (presumably a JSON Schema for the tool's
/// parameters — TODO confirm).
pub schema: serde_json::Value,
/// Implementation reference (free-form string; meaning defined by the consumer).
pub implementation: String,
/// Optional arbitrary JSON permissions data; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permissions: Option<serde_json::Value>,
/// Optional arbitrary JSON compliance data; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub compliance: Option<serde_json::Value>,
}
pub fn load_config<P: AsRef<Path>>(config_path: P) -> Result<SekuireConfig> {
let path = config_path.as_ref();
let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read config file: {}", path.display()))?;
let mut config: SekuireConfig = serde_yaml::from_str(&content)
.with_context(|| format!("Failed to parse config file: {}", path.display()))?;
resolve_env_vars(&mut config)?;
Ok(config)
}
/// Reads a system prompt file and returns its contents.
///
/// An absolute `prompt_path` is used unchanged. A relative one is joined onto `base_path`
/// when given, otherwise onto the current directory (`.`).
///
/// # Errors
/// Returns an error naming the resolved path if the file cannot be read as UTF-8 text.
pub fn load_system_prompt<P: AsRef<Path>>(prompt_path: P, base_path: Option<P>) -> Result<String> {
    let requested = prompt_path.as_ref();

    let resolved: PathBuf = match (requested.is_absolute(), base_path) {
        (true, _) => requested.to_path_buf(),
        (false, Some(base)) => base.as_ref().join(requested),
        (false, None) => PathBuf::from(".").join(requested),
    };

    fs::read_to_string(&resolved)
        .with_context(|| format!("Failed to read system prompt: {}", resolved.display()))
}
/// Reads and parses a JSON tools file into a [`ToolsSchema`].
///
/// An absolute `tools_path` is used unchanged. A relative one is joined onto `base_path`
/// when given, otherwise onto the current directory (`.`).
///
/// # Errors
/// Returns an error naming the resolved path if the file cannot be read, or if its contents
/// are not valid JSON for [`ToolsSchema`].
pub fn load_tools<P: AsRef<Path>>(tools_path: P, base_path: Option<P>) -> Result<ToolsSchema> {
    let requested = tools_path.as_ref();

    let full_path = if requested.is_absolute() {
        requested.to_path_buf()
    } else {
        let base_dir: PathBuf = match base_path {
            Some(dir) => dir.as_ref().to_path_buf(),
            None => PathBuf::from("."),
        };
        base_dir.join(requested)
    };

    let json_text = fs::read_to_string(&full_path)
        .with_context(|| format!("Failed to read tools file: {}", full_path.display()))?;
    let schema: ToolsSchema = serde_json::from_str(&json_text)
        .with_context(|| format!("Failed to parse tools file: {}", full_path.display()))?;
    Ok(schema)
}
/// Selects the agent configuration to use from `config`.
///
/// Resolution order:
/// 1. If `agents` is present, returns the agent named `agent_name`. When `agent_name` is
///    `None`, the agent whose name sorts first (byte-wise) is chosen.
/// 2. Otherwise, if a single `agent` is present, returns it (`agent_name` is ignored).
/// 3. Otherwise, if a top-level `llm` section is present, synthesizes an agent named after
///    the project, with `"./system_prompt.md"` as its system prompt, `"./tools.json"` as its
///    tools, and no memory or compliance settings.
///
/// # Errors
/// - `agents` is present but empty and no `agent_name` was given.
/// - `agent_name` does not match any key in `agents`.
/// - None of `agents`, `agent`, or `llm` is present.
pub fn get_agent_config(config: &SekuireConfig, agent_name: Option<&str>) -> Result<AgentConfig> {
    if let Some(agents) = &config.agents {
        let name = agent_name
            // `HashMap` iteration order is unspecified and randomized per process, so the
            // previous `keys().next()` picked a different default agent from run to run.
            // Taking the smallest name makes the default deterministic.
            .or_else(|| agents.keys().map(String::as_str).min())
            .context("No agents defined in sekuire.yml")?;
        return agents
            .get(name)
            .cloned()
            .with_context(|| format!("Agent '{}' not found in sekuire.yml", name));
    }

    if let Some(agent) = &config.agent {
        return Ok(agent.clone());
    }

    if let Some(llm) = &config.llm {
        return Ok(AgentConfig {
            name: config.project.name.clone(),
            system_prompt: "./system_prompt.md".to_string(),
            tools: "./tools.json".to_string(),
            llm: llm.clone(),
            memory: None,
            compliance: None,
        });
    }

    anyhow::bail!("No agent configuration found in sekuire.yml")
}
/// Expands environment-variable placeholders in every LLM section of `config`.
///
/// Visits the `llm` of each entry in `agents`, then the `llm` of the single `agent`, then the
/// top-level `llm`, stopping at the first error.
fn resolve_env_vars(config: &mut SekuireConfig) -> Result<()> {
    // Disjoint field borrows of `*config`, so all three iterators can coexist.
    let named_agents = config
        .agents
        .iter_mut()
        .flat_map(|agents| agents.values_mut())
        .map(|agent| &mut agent.llm);
    let single_agent = config.agent.iter_mut().map(|agent| &mut agent.llm);
    let top_level = config.llm.iter_mut();

    for llm in named_agents.chain(single_agent).chain(top_level) {
        resolve_llm_env_vars(llm)?;
    }
    Ok(())
}
/// Expands `${VAR}` / `${VAR:-default}` placeholders in `api_key_env` and, when present,
/// `base_url` of `llm`, in place. The other fields are left untouched.
fn resolve_llm_env_vars(llm: &mut LLMConfig) -> Result<()> {
    let expanded_key = resolve_env_var_string(&llm.api_key_env)?;
    llm.api_key_env = expanded_key;

    if let Some(url) = llm.base_url.as_mut() {
        *url = resolve_env_var_string(url.as_str())?;
    }
    Ok(())
}
/// Expands `${VAR}` and `${VAR:-default}` placeholders in `s` from the process environment.
///
/// - `${VAR}` becomes the value of `VAR`. If `VAR` is unset (or not valid Unicode), the
///   placeholder is left in place verbatim.
/// - `${VAR:-default}` becomes the value of `VAR`, or `default` if `VAR` is unset (or not valid
///   Unicode). Unlike POSIX shells, a variable that is set but empty yields the empty string,
///   not the default.
///
/// Substitution is a single left-to-right pass. Text inserted by a substitution is never
/// re-scanned, so an environment value that itself contains `${...}` is inserted literally.
///
/// # Errors
/// Currently never fails; the `Result` is kept for interface stability.
fn resolve_env_var_string(s: &str) -> Result<String> {
    // The pattern is a constant, so failure here would be a programming error, not bad input.
    let re = regex::Regex::new(r"\$\{([^}:]+)(?::-([^}]*))?\}")
        .expect("placeholder regex is a valid constant pattern");

    // `replace_all` substitutes each match exactly once against the ORIGINAL input. The
    // previous approach rewrote `result` with repeated `str::replace` calls. An env value that
    // contained a later placeholder's text (e.g. A="${B}") was then expanded a second time,
    // or clobbered.
    let expanded = re.replace_all(s, |caps: &regex::Captures<'_>| {
        match std::env::var(&caps[1]) {
            Ok(value) => value,
            Err(_) => caps
                .get(2)
                .map(|default| default.as_str().to_string())
                .unwrap_or_else(|| caps[0].to_string()),
        }
    });

    Ok(expanded.into_owned())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `LLMConfig` for tests that do not care about LLM details.
    fn sample_llm() -> LLMConfig {
        LLMConfig {
            provider: "openai".to_string(),
            model: "test-model".to_string(),
            api_key_env: "SEKUIRE_CFG_TEST_KEY".to_string(),
            temperature: None,
            max_tokens: None,
            streaming: None,
            base_url: None,
        }
    }

    /// Builds a `SekuireConfig` with only the project section populated.
    fn empty_config() -> SekuireConfig {
        SekuireConfig {
            project: ProjectConfig {
                name: "demo".to_string(),
                version: "0.1.0".to_string(),
                description: None,
                author: None,
                license: None,
            },
            agents: None,
            agent: None,
            llm: None,
            logger: None,
        }
    }

    // All environment mutation lives in this single test. Tests run on parallel threads, and
    // concurrent `set_var`/`var` from different tests can race. Names are prefixed because
    // generic names like `MISSING` may already be set in a developer or CI environment,
    // which would make the "default" assertion flaky.
    #[test]
    fn test_resolve_env_var_string() {
        std::env::set_var("SEKUIRE_CFG_TEST_VAR", "test_value");
        std::env::remove_var("SEKUIRE_CFG_TEST_MISSING");

        assert_eq!(
            resolve_env_var_string("${SEKUIRE_CFG_TEST_VAR}").unwrap(),
            "test_value"
        );
        assert_eq!(
            resolve_env_var_string("${SEKUIRE_CFG_TEST_MISSING:-default}").unwrap(),
            "default"
        );
        // A set variable wins over its default.
        assert_eq!(
            resolve_env_var_string("${SEKUIRE_CFG_TEST_VAR:-default}").unwrap(),
            "test_value"
        );
        // An empty default is allowed.
        assert_eq!(
            resolve_env_var_string("${SEKUIRE_CFG_TEST_MISSING:-}").unwrap(),
            ""
        );
        // An unset variable without a default is left verbatim.
        assert_eq!(
            resolve_env_var_string("${SEKUIRE_CFG_TEST_MISSING}").unwrap(),
            "${SEKUIRE_CFG_TEST_MISSING}"
        );
        // Placeholders embedded in text, including repeats, are all expanded.
        assert_eq!(
            resolve_env_var_string("https://${SEKUIRE_CFG_TEST_VAR}/${SEKUIRE_CFG_TEST_VAR}")
                .unwrap(),
            "https://test_value/test_value"
        );
        // Text without placeholders is returned unchanged.
        assert_eq!(resolve_env_var_string("plain").unwrap(), "plain");
    }

    #[test]
    fn test_get_agent_config_synthesizes_from_top_level_llm() {
        let mut config = empty_config();
        config.llm = Some(sample_llm());

        let agent = get_agent_config(&config, None).unwrap();
        assert_eq!(agent.name, "demo");
        assert_eq!(agent.system_prompt, "./system_prompt.md");
        assert_eq!(agent.tools, "./tools.json");
        assert_eq!(agent.llm.model, "test-model");
        assert!(agent.memory.is_none());
        assert!(agent.compliance.is_none());
    }

    #[test]
    fn test_get_agent_config_named_agents() {
        let mut config = empty_config();
        let mut agents = HashMap::new();
        agents.insert(
            "alpha".to_string(),
            AgentConfig {
                name: "alpha".to_string(),
                system_prompt: "a.md".to_string(),
                tools: "a.json".to_string(),
                llm: sample_llm(),
                memory: None,
                compliance: None,
            },
        );
        config.agents = Some(agents);

        assert_eq!(get_agent_config(&config, Some("alpha")).unwrap().name, "alpha");
        // With a single agent, the unnamed lookup is unambiguous.
        assert_eq!(get_agent_config(&config, None).unwrap().name, "alpha");
        assert!(get_agent_config(&config, Some("nope")).is_err());
    }

    #[test]
    fn test_get_agent_config_errors_without_any_agent_source() {
        assert!(get_agent_config(&empty_config(), None).is_err());
    }
}