use arc_swap::ArcSwap;
use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{error, info, warn};
/// Top-level ARES configuration, deserialized from a TOML file.
///
/// Only the `[auth]` table is mandatory; every other section falls back to its default
/// when omitted. Use [`AresConfig::load`] to parse and validate in one step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AresConfig {
    /// HTTP server settings (`[server]`). Omitting the table is equivalent to writing an
    /// empty one, because every `ServerConfig` field has its own serde default.
    #[serde(default)]
    pub server: ServerConfig,
    /// Authentication settings (`[auth]`). Required: `jwt_secret_env` and `api_key_env`
    /// have no serde default.
    pub auth: AuthConfig,
    /// Database settings (`[database]`). Omitting the table is equivalent to writing an
    /// empty one (`url` has a serde default and `qdrant` becomes `None`).
    #[serde(default)]
    pub database: DatabaseConfig,
    /// Named LLM providers (`[providers.<name>]`).
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,
    /// Named models (`[models.<name>]`), each bound to a provider.
    #[serde(default)]
    pub models: HashMap<String, ModelConfig>,
    /// Named tools (`[tools.<name>]`).
    #[serde(default)]
    pub tools: HashMap<String, ToolConfig>,
    /// Named agents (`[agents.<name>]`), each bound to a model.
    #[serde(default)]
    pub agents: HashMap<String, AgentConfig>,
    /// Named workflows (`[workflows.<name>]`), each starting at an agent.
    #[serde(default)]
    pub workflows: HashMap<String, WorkflowConfig>,
    /// Retrieval settings (`[rag]`).
    #[serde(default)]
    pub rag: RagConfig,
    /// Skill directories (`[skills]`), only with the `skills` feature.
    #[cfg(feature = "skills")]
    #[serde(default)]
    pub skills: Option<SkillsTomlConfig>,
    /// Directories for per-item config files and hot-reload settings (`[config]`).
    #[serde(default)]
    pub config: DynamicConfigPaths,
}
/// HTTP server settings from the `[server]` table.
///
/// Every field has a serde default, so an empty `[server]` table is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host address (default `127.0.0.1`).
    #[serde(default = "default_host")]
    pub host: String,
    /// TCP port (default `3000`).
    #[serde(default = "default_port")]
    pub port: u16,
    /// Log level name (default `info`).
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Allowed CORS origins (default `["http://localhost:3000"]`).
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,
    /// Rate-limit rate per second (default `100`).
    #[serde(default = "default_rate_limit")]
    pub rate_limit_per_second: u32,
    /// Rate-limit burst size (default `10`).
    #[serde(default = "default_rate_limit_burst")]
    pub rate_limit_burst: u32,
}
/// Serde default for `ServerConfig::host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `ServerConfig::port`.
fn default_port() -> u16 {
    3_000
}

/// Serde default for `ServerConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `ServerConfig::cors_origins`: only the local dev origin.
fn default_cors_origins() -> Vec<String> {
    let origin = String::from("http://localhost:3000");
    vec![origin]
}

/// Serde default for `ServerConfig::rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100_u32
}

/// Serde default for `ServerConfig::rate_limit_burst`.
fn default_rate_limit_burst() -> u32 {
    10_u32
}
impl Default for ServerConfig {
    /// Builds the same values serde fills in for an empty `[server]` table.
    fn default() -> Self {
        let host = default_host();
        let port = default_port();
        let log_level = default_log_level();
        let cors_origins = default_cors_origins();
        let rate_limit_per_second = default_rate_limit();
        let rate_limit_burst = default_rate_limit_burst();
        Self {
            host,
            port,
            log_level,
            cors_origins,
            rate_limit_per_second,
            rate_limit_burst,
        }
    }
}
/// Authentication settings from the `[auth]` table.
///
/// The file names the environment variables that hold the secrets; the secrets
/// themselves are read from the environment at runtime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Name of the environment variable holding the JWT signing secret.
    pub jwt_secret_env: String,
    /// Access-token lifetime; presumably seconds (default 900 = 15 min) — confirm with the token issuer.
    #[serde(default = "default_jwt_access_expiry")]
    pub jwt_access_expiry: i64,
    /// Refresh-token lifetime; presumably seconds (default 604800 = 7 days) — confirm.
    #[serde(default = "default_jwt_refresh_expiry")]
    pub jwt_refresh_expiry: i64,
    /// Name of the environment variable holding the API key.
    pub api_key_env: String,
}

/// Serde default for `AuthConfig::jwt_access_expiry` (900).
fn default_jwt_access_expiry() -> i64 {
    15 * 60
}

/// Serde default for `AuthConfig::jwt_refresh_expiry` (604800).
fn default_jwt_refresh_expiry() -> i64 {
    7 * 24 * 60 * 60
}

impl Default for AuthConfig {
    /// Uses the `JWT_SECRET` / `API_KEY` variable names and the default expiries.
    fn default() -> Self {
        Self {
            jwt_secret_env: String::from("JWT_SECRET"),
            api_key_env: String::from("API_KEY"),
            jwt_access_expiry: default_jwt_access_expiry(),
            jwt_refresh_expiry: default_jwt_refresh_expiry(),
        }
    }
}
/// Database settings from the `[database]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Database URL (default: a local PostgreSQL `ares` database).
    #[serde(default = "default_database_url")]
    pub url: String,
    /// Optional Qdrant settings (`[database.qdrant]`); `None` when the table is absent.
    pub qdrant: Option<QdrantConfig>,
}

/// Serde default for `DatabaseConfig::url`.
fn default_database_url() -> String {
    String::from("postgres://postgres:postgres@localhost:5432/ares")
}

impl Default for DatabaseConfig {
    /// Default URL and no Qdrant configuration.
    fn default() -> Self {
        let url = default_database_url();
        Self { url, qdrant: None }
    }
}

/// Qdrant connection settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConfig {
    /// Qdrant endpoint (default `http://localhost:6334`).
    #[serde(default = "default_qdrant_url")]
    pub url: String,
    /// Name of the environment variable holding the Qdrant API key, if any.
    /// `AresConfig::validate` requires that variable to be set when this is `Some`.
    pub api_key_env: Option<String>,
}

/// Serde default for `QdrantConfig::url`.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}

impl Default for QdrantConfig {
    /// Default endpoint and no API key.
    fn default() -> Self {
        let url = default_qdrant_url();
        Self {
            url,
            api_key_env: None,
        }
    }
}
/// An LLM provider from `[providers.<name>]`, selected by its `type` key.
///
/// The `type` value is the lowercased variant name: `ollama`, `openai`, `llamacpp`
/// or `anthropic`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// An Ollama server.
    Ollama {
        /// Server URL (default `http://localhost:11434`).
        #[serde(default = "default_ollama_url")]
        base_url: String,
        /// Model used when none is specified.
        default_model: String,
    },
    /// An OpenAI-compatible API.
    OpenAI {
        /// Name of the environment variable holding the API key.
        api_key_env: String,
        /// API base URL (default `https://api.openai.com/v1`).
        #[serde(default = "default_openai_base")]
        api_base: String,
        /// Model used when none is specified.
        default_model: String,
    },
    /// A local llama.cpp model file; `AresConfig::validate` checks that the file exists.
    LlamaCpp {
        /// Path to the model file.
        model_path: String,
        /// Context size (default 4096; presumably tokens — confirm).
        #[serde(default = "default_n_ctx")]
        n_ctx: u32,
        /// Thread count (default 4).
        #[serde(default = "default_n_threads")]
        n_threads: u32,
        /// Token limit (default 512).
        #[serde(default = "default_max_tokens")]
        max_tokens: u32,
    },
    /// The Anthropic API.
    Anthropic {
        /// Name of the environment variable holding the API key.
        api_key_env: String,
        /// Model used when none is specified.
        default_model: String,
    },
}
/// Serde default for `ProviderConfig::Ollama::base_url`.
fn default_ollama_url() -> String {
    String::from("http://localhost:11434")
}

/// Serde default for `ProviderConfig::OpenAI::api_base`.
fn default_openai_base() -> String {
    String::from("https://api.openai.com/v1")
}

/// Serde default for `ProviderConfig::LlamaCpp::n_ctx`.
fn default_n_ctx() -> u32 {
    4_096
}

/// Serde default for `ProviderConfig::LlamaCpp::n_threads`.
fn default_n_threads() -> u32 {
    4_u32
}

/// Serde default for `ProviderConfig::LlamaCpp::max_tokens`.
fn default_max_tokens() -> u32 {
    512_u32
}
/// A named model from `[models.<name>]`, bound to one of the configured providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Key of a `[providers.<name>]` entry; `AresConfig::validate` checks it exists.
    pub provider: String,
    /// Provider-side model identifier.
    pub model: String,
    /// Temperature setting (default `0.7`).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Token limit (default `512`).
    #[serde(default = "default_model_max_tokens")]
    pub max_tokens: u32,
    /// Optional `top_p`; `None` when not set.
    pub top_p: Option<f32>,
    /// Optional frequency penalty; `None` when not set.
    pub frequency_penalty: Option<f32>,
    /// Optional presence penalty; `None` when not set.
    pub presence_penalty: Option<f32>,
}

/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde default for `ModelConfig::max_tokens`.
fn default_model_max_tokens() -> u32 {
    512_u32
}
/// Settings for one tool from a `[tools.<name>]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// Whether the tool is enabled (default `true`). Disabled tools are left out by
    /// `AresConfig::enabled_tools` and `AresConfig::agent_tools`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Timeout in seconds (default `30`).
    #[serde(default = "default_tool_timeout")]
    pub timeout_secs: u64,
    /// Any other keys in the tool's table, preserved as raw TOML values.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}

/// Serde default helper for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}

/// Serde default for `ToolConfig::timeout_secs`.
fn default_tool_timeout() -> u64 {
    30
}

impl Default for ToolConfig {
    /// Matches what serde produces for an empty `[tools.<name>]` table.
    fn default() -> Self {
        Self {
            // Same helper as the serde attribute, so the two defaults cannot drift apart.
            enabled: default_true(),
            description: None,
            timeout_secs: default_tool_timeout(),
            extra: HashMap::new(),
        }
    }
}
/// An agent definition from `[agents.<name>]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Key of a `[models.<name>]` entry; `AresConfig::validate` checks it exists.
    pub model: String,
    /// Optional system prompt.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Tool names the agent may use. `AresConfig::validate` accepts `[tools]` keys and
    /// MCP tools named `<mcp client name>_<tool>`.
    #[serde(default)]
    pub tools: Vec<String>,
    /// Maximum tool iterations (default `10`).
    #[serde(default = "default_max_tool_iterations")]
    pub max_tool_iterations: usize,
    /// Whether tools may run in parallel (default `false`).
    #[serde(default)]
    pub parallel_tools: bool,
    /// Any other keys in the agent's table, preserved as raw TOML values.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}

/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    10_usize
}
/// A workflow definition from `[workflows.<name>]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
    /// Agent the workflow starts with; must exist in `[agents]`.
    pub entry_agent: String,
    /// Optional fallback agent; must exist and must differ from `entry_agent`.
    pub fallback_agent: Option<String>,
    /// Maximum depth (default `3`).
    #[serde(default = "default_max_depth")]
    pub max_depth: u8,
    /// Maximum iterations (default `5`).
    #[serde(default = "default_max_iterations")]
    pub max_iterations: u8,
    /// Whether sub-agents may run in parallel (default `false`).
    #[serde(default)]
    pub parallel_subagents: bool,
}

/// Serde default for `WorkflowConfig::max_depth`.
fn default_max_depth() -> u8 {
    3_u8
}

/// Serde default for `WorkflowConfig::max_iterations`.
fn default_max_iterations() -> u8 {
    5_u8
}
/// Skill directories from the optional `[skills]` table (only with the `skills` feature).
#[cfg(feature = "skills")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillsTomlConfig {
    /// Project-level skills directory, if configured.
    pub project_dir: Option<PathBuf>,
    /// Personal skills directory, if configured.
    pub personal_dir: Option<PathBuf>,
    /// Additional plugin directories, if configured.
    pub plugin_dirs: Option<Vec<PathBuf>>,
}
/// Retrieval-augmented-generation settings from the `[rag]` table.
///
/// Every field has a serde default, so an empty `[rag]` table is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Vector store backend name (default `ares-vector`).
    #[serde(default = "default_vector_store")]
    pub vector_store: String,
    /// Vector store path (default `./data/vectors`).
    #[serde(default = "default_vector_path")]
    pub vector_path: String,
    /// Dense embedding model (default `bge-small-en-v1.5`).
    #[serde(default = "default_embedding_model")]
    pub embedding_model: String,
    /// Whether sparse embeddings are enabled (default `false`).
    #[serde(default)]
    pub sparse_embeddings: bool,
    /// Sparse embedding model (default `splade-pp-en-v1`).
    #[serde(default = "default_sparse_model")]
    pub sparse_model: String,
    /// Chunking strategy name (default `word`).
    #[serde(default = "default_chunking_strategy")]
    pub chunking_strategy: String,
    /// Chunk size (default `200`); units presumably follow `chunking_strategy` — confirm.
    #[serde(default = "default_chunk_size")]
    pub chunk_size: usize,
    /// Overlap between consecutive chunks (default `50`).
    #[serde(default = "default_chunk_overlap")]
    pub chunk_overlap: usize,
    /// Minimum chunk size (default `20`).
    #[serde(default = "default_min_chunk_size")]
    pub min_chunk_size: usize,
    /// Search strategy name (default `semantic`).
    #[serde(default = "default_search_strategy")]
    pub search_strategy: String,
    /// Maximum number of search results (default `10`).
    #[serde(default = "default_search_limit")]
    pub search_limit: usize,
    /// Score threshold (default `0.0`).
    #[serde(default)]
    pub search_threshold: f32,
    /// Optional hybrid-search weights; `None` unless `[rag.hybrid_weights]` is given.
    #[serde(default)]
    pub hybrid_weights: Option<HybridWeightsConfig>,
    /// Whether reranking is enabled (default `false`).
    #[serde(default)]
    pub rerank_enabled: bool,
    /// Reranker model (default `bge-reranker-base`).
    #[serde(default = "default_reranker_model")]
    pub reranker_model: String,
    /// Rerank weight (default `0.6`).
    #[serde(default = "default_rerank_weight")]
    pub rerank_weight: f32,
}

/// Per-signal weights for hybrid search (`[rag.hybrid_weights]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridWeightsConfig {
    /// Semantic weight (default `0.5`).
    #[serde(default = "default_semantic_weight")]
    pub semantic: f32,
    /// BM25 weight (default `0.3`).
    #[serde(default = "default_bm25_weight")]
    pub bm25: f32,
    /// Fuzzy-match weight (default `0.2`).
    #[serde(default = "default_fuzzy_weight")]
    pub fuzzy: f32,
}

impl Default for HybridWeightsConfig {
    /// Matches what serde produces for an empty `[rag.hybrid_weights]` table.
    fn default() -> Self {
        // Use the serde default helpers so the two sources of defaults cannot drift apart.
        Self {
            semantic: default_semantic_weight(),
            bm25: default_bm25_weight(),
            fuzzy: default_fuzzy_weight(),
        }
    }
}

/// Serde default for `HybridWeightsConfig::semantic`.
fn default_semantic_weight() -> f32 {
    0.5
}

/// Serde default for `HybridWeightsConfig::bm25`.
fn default_bm25_weight() -> f32 {
    0.3
}

/// Serde default for `HybridWeightsConfig::fuzzy`.
fn default_fuzzy_weight() -> f32 {
    0.2
}
/// Serde default for `RagConfig::vector_store`.
fn default_vector_store() -> String {
    String::from("ares-vector")
}

/// Serde default for `RagConfig::vector_path`.
fn default_vector_path() -> String {
    String::from("./data/vectors")
}

/// Serde default for `RagConfig::embedding_model`.
fn default_embedding_model() -> String {
    String::from("bge-small-en-v1.5")
}

/// Serde default for `RagConfig::sparse_model`.
fn default_sparse_model() -> String {
    String::from("splade-pp-en-v1")
}

/// Serde default for `RagConfig::chunking_strategy`.
fn default_chunking_strategy() -> String {
    String::from("word")
}

/// Serde default for `RagConfig::chunk_size`.
fn default_chunk_size() -> usize {
    200_usize
}

/// Serde default for `RagConfig::chunk_overlap`.
fn default_chunk_overlap() -> usize {
    50_usize
}

/// Serde default for `RagConfig::min_chunk_size`.
fn default_min_chunk_size() -> usize {
    20_usize
}

/// Serde default for `RagConfig::search_strategy`.
fn default_search_strategy() -> String {
    String::from("semantic")
}

/// Serde default for `RagConfig::search_limit`.
fn default_search_limit() -> usize {
    10_usize
}

/// Serde default for `RagConfig::reranker_model`.
fn default_reranker_model() -> String {
    String::from("bge-reranker-base")
}

/// Serde default for `RagConfig::rerank_weight`.
fn default_rerank_weight() -> f32 {
    0.6_f32
}
impl Default for RagConfig {
    /// Produces the same values serde fills in for an empty `[rag]` table.
    fn default() -> Self {
        Self {
            // Storage and embeddings.
            vector_store: default_vector_store(),
            vector_path: default_vector_path(),
            embedding_model: default_embedding_model(),
            sparse_embeddings: bool::default(),
            sparse_model: default_sparse_model(),
            // Chunking.
            chunking_strategy: default_chunking_strategy(),
            chunk_size: default_chunk_size(),
            chunk_overlap: default_chunk_overlap(),
            min_chunk_size: default_min_chunk_size(),
            // Search.
            search_strategy: default_search_strategy(),
            search_limit: default_search_limit(),
            search_threshold: f32::default(),
            hybrid_weights: None,
            // Reranking.
            rerank_enabled: bool::default(),
            reranker_model: default_reranker_model(),
            rerank_weight: default_rerank_weight(),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicConfigPaths {
#[serde(default = "default_agents_dir")]
pub agents_dir: std::path::PathBuf,
#[serde(default = "default_workflows_dir")]
pub workflows_dir: std::path::PathBuf,
#[serde(default = "default_models_dir")]
pub models_dir: std::path::PathBuf,
#[serde(default = "default_tools_dir")]
pub tools_dir: std::path::PathBuf,
#[serde(default = "default_mcps_dir")]
pub mcps_dir: std::path::PathBuf,
#[serde(default = "default_hot_reload")]
pub hot_reload: bool,
#[serde(default = "default_watch_interval")]
pub watch_interval_ms: u64,
}
/// Serde default for `DynamicConfigPaths::agents_dir`.
fn default_agents_dir() -> PathBuf {
    "config/agents".into()
}

/// Serde default for `DynamicConfigPaths::workflows_dir`.
fn default_workflows_dir() -> PathBuf {
    "config/workflows".into()
}

/// Serde default for `DynamicConfigPaths::models_dir`.
fn default_models_dir() -> PathBuf {
    "config/models".into()
}

/// Serde default for `DynamicConfigPaths::tools_dir`.
fn default_tools_dir() -> PathBuf {
    "config/tools".into()
}

/// Serde default for `DynamicConfigPaths::mcps_dir`.
fn default_mcps_dir() -> PathBuf {
    "config/mcps".into()
}

/// Serde default for `DynamicConfigPaths::hot_reload`.
fn default_hot_reload() -> bool {
    true
}

/// Serde default for `DynamicConfigPaths::watch_interval_ms`.
fn default_watch_interval() -> u64 {
    1_000
}
impl Default for DynamicConfigPaths {
    /// Produces the same values serde fills in for an empty `[config]` table.
    fn default() -> Self {
        let agents_dir = default_agents_dir();
        let workflows_dir = default_workflows_dir();
        let models_dir = default_models_dir();
        let tools_dir = default_tools_dir();
        let mcps_dir = default_mcps_dir();
        Self {
            agents_dir,
            workflows_dir,
            models_dir,
            tools_dir,
            mcps_dir,
            hot_reload: default_hot_reload(),
            watch_interval_ms: default_watch_interval(),
        }
    }
}
/// A non-fatal finding produced by `AresConfig::validate_with_warnings`.
#[derive(Debug, Clone)]
pub struct ConfigWarning {
    /// Category of the finding.
    pub kind: ConfigWarningKind,
    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
}

/// Category of a [`ConfigWarning`].
///
/// A fieldless enum, so it is cheap to copy and usable as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigWarningKind {
    /// A provider that no model references.
    UnusedProvider,
    /// A model that no agent references.
    UnusedModel,
    /// A tool that no agent lists.
    UnusedTool,
    /// An agent that no workflow uses as entry or fallback.
    UnusedAgent,
}

impl std::fmt::Display for ConfigWarning {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
/// Errors raised while loading, validating or watching the configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file path does not exist.
    #[error("Configuration file not found: {0}")]
    FileNotFound(PathBuf),
    /// An I/O failure, e.g. reading the file or resolving the current directory.
    #[error("Failed to read configuration file: {0}")]
    ReadError(#[from] std::io::Error),
    /// The file is not valid TOML for the expected structure.
    #[error("Failed to parse TOML: {0}")]
    ParseError(#[from] toml::de::Error),
    /// A free-form validation failure (e.g. missing model file, too-short JWT secret).
    #[error("Validation error: {0}")]
    ValidationError(String),
    /// `(variable)`: an environment variable named by the config is not set.
    #[error("Environment variable '{0}' referenced in config is not set")]
    MissingEnvVar(String),
    /// `(provider, model)`: a model names an unknown provider.
    #[error("Provider '{0}' referenced by model '{1}' does not exist")]
    MissingProvider(String, String),
    /// `(model, agent)`: an agent names an unknown model.
    #[error("Model '{0}' referenced by agent '{1}' does not exist")]
    MissingModel(String, String),
    /// `(agent, workflow)`: a workflow names an unknown entry or fallback agent.
    #[error("Agent '{0}' referenced by workflow '{1}' does not exist")]
    MissingAgent(String, String),
    /// `(tool, agent)`: an agent lists a tool that is neither configured nor an MCP tool.
    #[error("Tool '{0}' referenced by agent '{1}' does not exist")]
    MissingTool(String, String),
    /// A workflow refers back to the same agent (see `AresConfig::validate`).
    #[error("Circular reference detected: {0}")]
    CircularReference(String),
    /// The filesystem watcher could not be created or registered.
    #[error("Watch error: {0}")]
    WatchError(#[from] notify::Error),
}
impl AresConfig {
    /// Minimum length of the JWT signing secret accepted by [`AresConfig::jwt_secret`].
    ///
    /// Compared against `String::len`, i.e. measured in bytes.
    const JWT_SECRET_MIN_LENGTH: usize = 32;

    /// Reads, parses and validates the configuration file at `path`.
    ///
    /// # Errors
    ///
    /// Everything [`AresConfig::load_unchecked`] can return, plus any error from
    /// [`AresConfig::validate`].
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
        let config = Self::load_unchecked(path)?;
        config.validate()?;
        Ok(config)
    }

    /// Reads and parses the configuration file at `path` without validating it.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::FileNotFound`] if `path` does not exist.
    /// - [`ConfigError::ReadError`] if the file cannot be read as UTF-8 text.
    /// - [`ConfigError::ParseError`] if the contents do not deserialize into `AresConfig`.
    pub fn load_unchecked<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
        let path = path.as_ref();
        // Checked up front so a missing file gets a dedicated error, not a generic I/O one.
        if !path.exists() {
            return Err(ConfigError::FileNotFound(path.to_path_buf()));
        }
        let content = fs::read_to_string(path)?;
        let config: AresConfig = toml::from_str(&content)?;
        Ok(config)
    }

    /// Checks that the configuration is internally consistent and that the environment
    /// variables it names are set.
    ///
    /// The checks run in this order:
    /// 1. the auth env vars, and the Qdrant API-key env var if configured, are set;
    /// 2. OpenAI/Anthropic API-key env vars are set, and LlamaCpp model files exist;
    /// 3. every model's provider exists;
    /// 4. every agent's model exists, and every agent tool is either a `[tools]` key or is
    ///    named `<mcp client>_...` for a client from [`AresConfig::mcp_client_names`];
    /// 5. every workflow's entry and fallback agents exist;
    /// 6. no workflow uses the same agent as entry and fallback.
    ///
    /// Validation stops at the first failure. The collections are `HashMap`s, so when
    /// several problems exist, which one is reported is not deterministic. The JWT secret
    /// length is not checked here; see [`AresConfig::jwt_secret`].
    ///
    /// # Errors
    ///
    /// [`ConfigError::MissingEnvVar`], [`ConfigError::ValidationError`],
    /// [`ConfigError::MissingProvider`], [`ConfigError::MissingModel`],
    /// [`ConfigError::MissingTool`], [`ConfigError::MissingAgent`] or
    /// [`ConfigError::CircularReference`].
    pub fn validate(&self) -> Result<(), ConfigError> {
        self.validate_env_var(&self.auth.jwt_secret_env)?;
        self.validate_env_var(&self.auth.api_key_env)?;
        if let Some(env) = self
            .database
            .qdrant
            .as_ref()
            .and_then(|qdrant| qdrant.api_key_env.as_deref())
        {
            self.validate_env_var(env)?;
        }

        for (name, provider) in &self.providers {
            match provider {
                ProviderConfig::OpenAI { api_key_env, .. }
                | ProviderConfig::Anthropic { api_key_env, .. } => {
                    self.validate_env_var(api_key_env)?;
                }
                ProviderConfig::LlamaCpp { model_path, .. } => {
                    if !Path::new(model_path).exists() {
                        return Err(ConfigError::ValidationError(format!(
                            "LlamaCpp model path does not exist: {} (provider: {})",
                            model_path, name
                        )));
                    }
                }
                // Ollama providers have no env var or local file to check.
                ProviderConfig::Ollama { .. } => {}
            }
        }

        for (model_name, model) in &self.models {
            if !self.providers.contains_key(&model.provider) {
                return Err(ConfigError::MissingProvider(
                    model.provider.clone(),
                    model_name.clone(),
                ));
            }
        }

        // MCP client names come from scanning and parsing files on disk, so load them
        // lazily and at most once, not once per tool reference.
        let mut mcp_names: Option<Vec<String>> = None;
        for (agent_name, agent) in &self.agents {
            if !self.models.contains_key(&agent.model) {
                return Err(ConfigError::MissingModel(
                    agent.model.clone(),
                    agent_name.clone(),
                ));
            }
            for tool_name in &agent.tools {
                if self.tools.contains_key(tool_name) {
                    continue;
                }
                // MCP tools are referenced as `<mcp client name>_<tool>`.
                let is_mcp_tool = tool_name.contains('_')
                    && mcp_names
                        .get_or_insert_with(|| self.mcp_client_names())
                        .iter()
                        .any(|mcp_name| {
                            tool_name
                                .strip_prefix(mcp_name.as_str())
                                .is_some_and(|rest| rest.starts_with('_'))
                        });
                if !is_mcp_tool {
                    return Err(ConfigError::MissingTool(
                        tool_name.clone(),
                        agent_name.clone(),
                    ));
                }
            }
        }

        for (workflow_name, workflow) in &self.workflows {
            if !self.agents.contains_key(&workflow.entry_agent) {
                return Err(ConfigError::MissingAgent(
                    workflow.entry_agent.clone(),
                    workflow_name.clone(),
                ));
            }
            if let Some(fallback) = &workflow.fallback_agent {
                if !self.agents.contains_key(fallback) {
                    return Err(ConfigError::MissingAgent(
                        fallback.clone(),
                        workflow_name.clone(),
                    ));
                }
            }
        }

        self.detect_circular_references()
    }

    /// Rejects workflows that name the same agent as both `entry_agent` and
    /// `fallback_agent`.
    ///
    /// This direct self-reference is the only cycle checked. No agent-to-agent links are
    /// visible in `AgentConfig`, so no longer chains are walked.
    fn detect_circular_references(&self) -> Result<(), ConfigError> {
        for (workflow_name, workflow) in &self.workflows {
            if workflow.fallback_agent.as_deref() == Some(workflow.entry_agent.as_str()) {
                return Err(ConfigError::CircularReference(format!(
                    "Workflow '{}' has entry_agent '{}' that equals fallback_agent",
                    workflow_name, workflow.entry_agent
                )));
            }
        }
        Ok(())
    }

    /// Runs [`AresConfig::validate`], then collects warnings about defined-but-unused
    /// providers, models, tools and agents.
    ///
    /// # Errors
    ///
    /// Any error from [`AresConfig::validate`]; warnings are only produced on success.
    pub fn validate_with_warnings(&self) -> Result<Vec<ConfigWarning>, ConfigError> {
        self.validate()?;
        let mut warnings = Vec::new();
        warnings.extend(self.check_unused_providers());
        warnings.extend(self.check_unused_models());
        warnings.extend(self.check_unused_tools());
        warnings.extend(self.check_unused_agents());
        Ok(warnings)
    }

    /// Warns about providers that no model references.
    fn check_unused_providers(&self) -> Vec<ConfigWarning> {
        use std::collections::HashSet;
        let referenced: HashSet<_> = self.models.values().map(|m| m.provider.as_str()).collect();
        self.providers
            .keys()
            .filter(|name| !referenced.contains(name.as_str()))
            .map(|name| ConfigWarning {
                kind: ConfigWarningKind::UnusedProvider,
                message: format!(
                    "Provider '{}' is defined but not referenced by any model",
                    name
                ),
            })
            .collect()
    }

    /// Warns about models that no agent references.
    fn check_unused_models(&self) -> Vec<ConfigWarning> {
        use std::collections::HashSet;
        let referenced: HashSet<_> = self.agents.values().map(|a| a.model.as_str()).collect();
        self.models
            .keys()
            .filter(|name| !referenced.contains(name.as_str()))
            .map(|name| ConfigWarning {
                kind: ConfigWarningKind::UnusedModel,
                message: format!(
                    "Model '{}' is defined but not referenced by any agent",
                    name
                ),
            })
            .collect()
    }

    /// Warns about `[tools]` entries that no agent lists.
    fn check_unused_tools(&self) -> Vec<ConfigWarning> {
        use std::collections::HashSet;
        let referenced: HashSet<_> = self
            .agents
            .values()
            .flat_map(|a| a.tools.iter().map(|t| t.as_str()))
            .collect();
        self.tools
            .keys()
            .filter(|name| !referenced.contains(name.as_str()))
            .map(|name| ConfigWarning {
                kind: ConfigWarningKind::UnusedTool,
                message: format!("Tool '{}' is defined but not referenced by any agent", name),
            })
            .collect()
    }

    /// Warns about agents that no workflow uses as entry or fallback.
    ///
    /// Agents named `orchestrator` or `router` are never reported.
    fn check_unused_agents(&self) -> Vec<ConfigWarning> {
        use std::collections::HashSet;
        let referenced: HashSet<&str> = self
            .workflows
            .values()
            .flat_map(|w| {
                std::iter::once(w.entry_agent.as_str()).chain(w.fallback_agent.as_deref())
            })
            .collect();
        let system_agents: HashSet<&str> = ["orchestrator", "router"].into_iter().collect();
        self.agents
            .keys()
            .filter(|name| {
                !referenced.contains(name.as_str()) && !system_agents.contains(name.as_str())
            })
            .map(|name| ConfigWarning {
                kind: ConfigWarningKind::UnusedAgent,
                message: format!(
                    "Agent '{}' is defined but not referenced by any workflow",
                    name
                ),
            })
            .collect()
    }

    /// Fails with [`ConfigError::MissingEnvVar`] unless `name` is set to valid Unicode.
    fn validate_env_var(&self, name: &str) -> Result<(), ConfigError> {
        std::env::var(name)
            .map(|_| ())
            .map_err(|_| ConfigError::MissingEnvVar(name.to_string()))
    }

    /// Returns the value of environment variable `env_name`, or `None` if unset or not Unicode.
    pub fn resolve_env(&self, env_name: &str) -> Option<String> {
        std::env::var(env_name).ok()
    }

    /// Returns the `name` values declared by the `*.toon` files in `config.mcps_dir`.
    ///
    /// The files are parsed with the TOML parser. A missing directory yields an empty list.
    /// Unreadable entries, files that fail to parse, and files without a string `name` are
    /// skipped silently. Order follows `read_dir`, which is platform-dependent.
    /// NOTE(review): confirm `.toon` files are really TOML; otherwise they are all skipped.
    pub fn mcp_client_names(&self) -> Vec<String> {
        let Ok(entries) = fs::read_dir(&self.config.mcps_dir) else {
            return Vec::new();
        };
        entries
            .filter_map(|entry| {
                let path = entry.ok()?.path();
                if path.extension()?.to_str()? != "toon" {
                    return None;
                }
                let content = fs::read_to_string(&path).ok()?;
                let value: toml::Value = toml::from_str(&content).ok()?;
                value.get("name")?.as_str().map(String::from)
            })
            .collect()
    }

    /// Returns the JWT signing secret from the env var named by `auth.jwt_secret_env`.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::MissingEnvVar`] if the variable is unset.
    /// - [`ConfigError::ValidationError`] if the secret is shorter than 32 bytes.
    pub fn jwt_secret(&self) -> Result<String, ConfigError> {
        let env_name = &self.auth.jwt_secret_env;
        let secret = self
            .resolve_env(env_name)
            .ok_or_else(|| ConfigError::MissingEnvVar(env_name.clone()))?;
        if secret.len() < Self::JWT_SECRET_MIN_LENGTH {
            // Name the configured variable: it is not necessarily `JWT_SECRET`.
            return Err(ConfigError::ValidationError(format!(
                "{} must be at least {} characters for security (current: {} chars). \
                 Use a cryptographically random string, e.g.: openssl rand -base64 32",
                env_name,
                Self::JWT_SECRET_MIN_LENGTH,
                secret.len()
            )));
        }
        Ok(secret)
    }

    /// Returns the API key from the env var named by `auth.api_key_env`.
    ///
    /// # Errors
    ///
    /// [`ConfigError::MissingEnvVar`] if the variable is unset.
    pub fn api_key(&self) -> Result<String, ConfigError> {
        self.resolve_env(&self.auth.api_key_env)
            .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.api_key_env.clone()))
    }

    /// Looks up a provider by name.
    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
        self.providers.get(name)
    }

    /// Looks up a model by name.
    pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
        self.models.get(name)
    }

    /// Looks up an agent by name.
    pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
        self.agents.get(name)
    }

    /// Looks up a tool by name.
    pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
        self.tools.get(name)
    }

    /// Looks up a workflow by name.
    pub fn get_workflow(&self, name: &str) -> Option<&WorkflowConfig> {
        self.workflows.get(name)
    }

    /// Names of all `[tools]` entries with `enabled = true`, in `HashMap` order.
    pub fn enabled_tools(&self) -> Vec<&str> {
        self.tools
            .iter()
            .filter(|(_, config)| config.enabled)
            .map(|(name, _)| name.as_str())
            .collect()
    }

    /// The agent's tools that are configured in `[tools]` and enabled, in the agent's order.
    ///
    /// Tools not present in `[tools]` (including MCP tools accepted by `validate`) are
    /// omitted. An unknown agent yields an empty list.
    pub fn agent_tools(&self, agent_name: &str) -> Vec<&str> {
        self.get_agent(agent_name)
            .map(|agent| {
                agent
                    .tools
                    .iter()
                    .filter(|t| self.get_tool(t).map(|tc| tc.enabled).unwrap_or(false))
                    .map(|s| s.as_str())
                    .collect()
            })
            .unwrap_or_default()
    }
}
/// Owns the live [`AresConfig`] and can hot-reload it when the file on disk changes.
///
/// Readers take cheap `Arc` snapshots with [`AresConfigManager::config`]. A reload swaps
/// in a whole new `AresConfig`, so snapshots taken earlier keep the old value.
pub struct AresConfigManager {
    /// Current configuration, shared with the hot-reload task and with clones of this manager.
    config: Arc<ArcSwap<AresConfig>>,
    /// Path of the TOML file; `new` always stores an absolute path.
    config_path: PathBuf,
    /// Active filesystem watcher, set by `start_watching`. Not shared by clones.
    watcher: RwLock<Option<RecommendedWatcher>>,
    /// Sender half of the reload channel created by `start_watching`.
    reload_tx: Option<mpsc::UnboundedSender<()>>,
}

impl AresConfigManager {
    /// Delay between the first change event of a burst and the reload. Events that arrive
    /// during the delay are folded into that single reload.
    const RELOAD_DEBOUNCE: Duration = Duration::from_millis(500);

    /// Loads and validates the config at `path`. A relative `path` is resolved against the
    /// current directory.
    ///
    /// # Errors
    ///
    /// [`ConfigError::ReadError`] if the current directory cannot be determined, plus any
    /// error from [`AresConfig::load`].
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
        let path = path.as_ref();
        // Resolve once, so later reloads do not depend on the process's current directory.
        let path = if path.is_absolute() {
            path.to_path_buf()
        } else {
            std::env::current_dir()?.join(path)
        };
        let config = AresConfig::load(&path)?;
        Ok(Self {
            config: Arc::new(ArcSwap::from_pointee(config)),
            config_path: path,
            watcher: RwLock::new(None),
            reload_tx: None,
        })
    }

    /// Returns a snapshot of the current configuration.
    pub fn config(&self) -> Arc<AresConfig> {
        self.config.load_full()
    }

    /// Reloads the file synchronously and swaps in the new configuration.
    ///
    /// # Errors
    ///
    /// Any error from [`AresConfig::load`]. On error the previous configuration is kept.
    pub fn reload(&self) -> Result<(), ConfigError> {
        info!("Reloading configuration from {:?}", self.config_path);
        let new_config = AresConfig::load(&self.config_path)?;
        self.config.store(Arc::new(new_config));
        info!("Configuration reloaded successfully");
        Ok(())
    }

    /// Starts watching the config file's directory and hot-reloads on changes.
    ///
    /// Bursts of modify/create events are debounced into a single reload, performed
    /// [`Self::RELOAD_DEBOUNCE`] after the first event. A reload that fails validation is
    /// logged and the previous configuration is kept. Calling this again replaces the
    /// previous watcher. The `hot_reload` flag in the config is not consulted here.
    ///
    /// # Errors
    ///
    /// [`ConfigError::WatchError`] if the watcher cannot be created or registered; in that
    /// case the manager's existing watcher state is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    pub fn start_watching(&mut self) -> Result<(), ConfigError> {
        let (tx, mut rx) = mpsc::unbounded_channel::<()>();
        let event_tx = tx.clone();
        let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
            match res {
                Ok(event) => {
                    if event.kind.is_modify() || event.kind.is_create() {
                        // A send error only means the reload task is gone; nothing to do.
                        let _ = event_tx.send(());
                    }
                }
                Err(e) => {
                    error!("Config watcher error: {:?}", e);
                }
            }
        })?;
        // Watch the directory rather than the file, so replacing the file (e.g. write +
        // rename) is still noticed.
        if let Some(parent) = self.config_path.parent() {
            watcher.watch(parent, RecursiveMode::NonRecursive)?;
        }
        // Commit state only after the watcher is registered successfully.
        *self.watcher.write() = Some(watcher);
        self.reload_tx = Some(tx);

        let config_path = self.config_path.clone();
        let config_arc = Arc::clone(&self.config);
        let debounce = Self::RELOAD_DEBOUNCE;
        tokio::spawn(async move {
            // Trailing-edge debounce: after the first event of a burst, wait, discard the
            // events that queued up meanwhile, then reload once. Any event that arrives
            // after the drain triggers another cycle, so no write is left unapplied.
            while rx.recv().await.is_some() {
                tokio::time::sleep(debounce).await;
                while rx.try_recv().is_ok() {}
                match AresConfig::load(&config_path) {
                    Ok(new_config) => {
                        config_arc.store(Arc::new(new_config));
                        info!("Configuration hot-reloaded successfully");
                    }
                    Err(e) => {
                        warn!(
                            "Failed to hot-reload config: {}. Keeping previous config.",
                            e
                        );
                    }
                }
            }
        });
        info!("Configuration hot-reload watcher started");
        Ok(())
    }

    /// Drops the filesystem watcher, so no further change events are delivered.
    ///
    /// The background reload task stays parked on its channel, because `reload_tx` (and
    /// any clones of this manager) still hold a sender. With no new events, it does nothing.
    pub fn stop_watching(&self) {
        *self.watcher.write() = None;
        info!("Configuration hot-reload watcher stopped");
    }
}

impl Clone for AresConfigManager {
    /// Creates a handle that shares the live configuration (reloads are visible through
    /// both) but owns no watcher. The reload channel sender is cloned as well.
    fn clone(&self) -> Self {
        Self {
            config: Arc::clone(&self.config),
            config_path: self.config_path.clone(),
            watcher: RwLock::new(None),
            reload_tx: self.reload_tx.clone(),
        }
    }
}

impl AresConfigManager {
    /// Wraps an already-built config without touching the filesystem.
    ///
    /// The stored path is the placeholder `test-config.toml`, so `reload` and
    /// `start_watching` on such a manager operate on that relative path.
    pub fn from_config(config: AresConfig) -> Self {
        Self {
            config: Arc::new(ArcSwap::from_pointee(config)),
            config_path: PathBuf::from("test-config.toml"),
            watcher: RwLock::new(None),
            reload_tx: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Once;

    /// Sets the environment variables referenced by the test configs, once per process.
    ///
    /// `validate()` only checks that these variables exist, so one shared value serves
    /// every test.
    fn set_test_env() {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            // SAFETY: `set_var` is unsound if another thread accesses the environment
            // concurrently. Within this module, every test that reads the environment calls
            // this helper first. `call_once` blocks those callers until the writes finish,
            // and the environment is never written again afterwards. (Tests elsewhere in the
            // crate that mutate the environment are outside this guarantee.)
            unsafe {
                std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
                std::env::set_var("TEST_API_KEY", "test-key");
            }
        });
    }

    /// Parses `content` as an `AresConfig`, panicking on malformed test input.
    fn parse_config(content: &str) -> AresConfig {
        toml::from_str(content).expect("test config should parse")
    }

    /// A small, fully connected configuration used by several tests.
    fn create_test_config() -> String {
        r#"
[server]
host = "127.0.0.1"
port = 3000
log_level = "debug"
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
jwt_access_expiry = 900
jwt_refresh_expiry = 604800
api_key_env = "TEST_API_KEY"
[database]
url = "./data/test.db"
[providers.ollama-local]
type = "ollama"
base_url = "http://localhost:11434"
default_model = "ministral-3:3b"
[models.default]
provider = "ollama-local"
model = "ministral-3:3b"
temperature = 0.7
max_tokens = 512
[tools.calculator]
enabled = true
description = "Basic calculator"
timeout_secs = 10
[agents.router]
model = "default"
tools = []
max_tool_iterations = 5
[workflows.default]
entry_agent = "router"
max_depth = 3
max_iterations = 5
"#
        .to_string()
    }

    #[test]
    fn test_parse_config() {
        let config = parse_config(&create_test_config());
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
        assert!(config.providers.contains_key("ollama-local"));
        assert!(config.models.contains_key("default"));
        assert!(config.agents.contains_key("router"));
    }

    #[test]
    fn test_validation_missing_provider() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[models.test]
provider = "nonexistent"
model = "test"
"#;
        let result = parse_config(content).validate();
        assert!(matches!(result, Err(ConfigError::MissingProvider(_, _))));
    }

    #[test]
    fn test_validation_missing_model() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[agents.test]
model = "nonexistent"
"#;
        let result = parse_config(content).validate();
        assert!(matches!(result, Err(ConfigError::MissingModel(_, _))));
    }

    #[test]
    fn test_validation_missing_tool() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[agents.test]
model = "default"
tools = ["nonexistent_tool"]
"#;
        let result = parse_config(content).validate();
        assert!(matches!(result, Err(ConfigError::MissingTool(_, _))));
    }

    #[test]
    fn test_validation_missing_workflow_agent() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[workflows.test]
entry_agent = "nonexistent_agent"
"#;
        let result = parse_config(content).validate();
        assert!(matches!(result, Err(ConfigError::MissingAgent(_, _))));
    }

    #[test]
    fn test_get_provider() {
        let config = parse_config(&create_test_config());
        assert!(config.get_provider("ollama-local").is_some());
        assert!(config.get_provider("nonexistent").is_none());
    }

    #[test]
    fn test_get_model() {
        let config = parse_config(&create_test_config());
        assert!(config.get_model("default").is_some());
        assert!(config.get_model("nonexistent").is_none());
    }

    #[test]
    fn test_get_agent() {
        let config = parse_config(&create_test_config());
        assert!(config.get_agent("router").is_some());
        assert!(config.get_agent("nonexistent").is_none());
    }

    #[test]
    fn test_get_tool() {
        let config = parse_config(&create_test_config());
        assert!(config.get_tool("calculator").is_some());
        assert!(config.get_tool("nonexistent").is_none());
    }

    #[test]
    fn test_enabled_tools() {
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[tools.enabled_tool]
enabled = true
[tools.disabled_tool]
enabled = false
"#;
        let config = parse_config(content);
        let enabled = config.enabled_tools();
        assert!(enabled.contains(&"enabled_tool"));
        assert!(!enabled.contains(&"disabled_tool"));
    }

    #[test]
    fn test_defaults() {
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
"#;
        let config = parse_config(content);
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
        assert_eq!(config.server.log_level, "info");
        assert_eq!(config.auth.jwt_access_expiry, 900);
        assert_eq!(config.auth.jwt_refresh_expiry, 604800);
        assert_eq!(config.database.url, "postgres://postgres:postgres@localhost:5432/ares");
        assert_eq!(config.rag.embedding_model, "bge-small-en-v1.5");
        assert_eq!(config.rag.chunk_size, 200);
        assert_eq!(config.rag.chunk_overlap, 50);
        assert_eq!(config.rag.vector_store, "ares-vector");
        assert_eq!(config.rag.search_strategy, "semantic");
    }

    #[test]
    fn test_config_manager_from_config() {
        let config = parse_config(&create_test_config());
        let manager = AresConfigManager::from_config(config.clone());
        let loaded = manager.config();
        assert_eq!(loaded.server.host, config.server.host);
        assert_eq!(loaded.server.port, config.server.port);
    }

    #[test]
    fn test_circular_reference_detection() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[agents.agent_a]
model = "default"
[workflows.circular]
entry_agent = "agent_a"
fallback_agent = "agent_a"
"#;
        let result = parse_config(content).validate();
        assert!(matches!(result, Err(ConfigError::CircularReference(_))));
    }

    #[test]
    fn test_unused_provider_warning() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.used]
type = "ollama"
default_model = "ministral-3:3b"
[providers.unused]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "used"
model = "ministral-3:3b"
[agents.router]
model = "default"
"#;
        let warnings = parse_config(content).validate_with_warnings().unwrap();
        assert!(warnings
            .iter()
            .any(|w| w.kind == ConfigWarningKind::UnusedProvider && w.message.contains("unused")));
    }

    #[test]
    fn test_unused_model_warning() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.used]
provider = "test"
model = "ministral-3:3b"
[models.unused]
provider = "test"
model = "other"
[agents.router]
model = "used"
"#;
        let warnings = parse_config(content).validate_with_warnings().unwrap();
        assert!(warnings
            .iter()
            .any(|w| w.kind == ConfigWarningKind::UnusedModel && w.message.contains("unused")));
    }

    #[test]
    fn test_unused_tool_warning() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[tools.used_tool]
enabled = true
[tools.unused_tool]
enabled = true
[agents.router]
model = "default"
tools = ["used_tool"]
"#;
        let warnings = parse_config(content).validate_with_warnings().unwrap();
        assert!(warnings
            .iter()
            .any(|w| w.kind == ConfigWarningKind::UnusedTool && w.message.contains("unused_tool")));
    }

    #[test]
    fn test_unused_agent_warning() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[agents.router]
model = "default"
[agents.orphaned]
model = "default"
[workflows.test_flow]
entry_agent = "router"
"#;
        let warnings = parse_config(content).validate_with_warnings().unwrap();
        assert!(warnings
            .iter()
            .any(|w| w.kind == ConfigWarningKind::UnusedAgent && w.message.contains("orphaned")));
    }

    #[test]
    fn test_no_warnings_for_fully_connected_config() {
        set_test_env();
        let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[tools.calc]
enabled = true
[agents.router]
model = "default"
tools = ["calc"]
[workflows.main]
entry_agent = "router"
"#;
        let warnings = parse_config(content).validate_with_warnings().unwrap();
        assert!(
            warnings.is_empty(),
            "Expected no warnings but got: {:?}",
            warnings
        );
    }
}