1use arc_swap::ArcSwap;
12use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::mpsc;
21use tracing::{error, info, warn};
22
/// Root of the Ares configuration, deserialized from a TOML file (see `AresConfig::load`).
///
/// `server`, `auth`, and `database` carry no `#[serde(default)]`, so their tables must be
/// present in the file (they may be empty); every other section is optional and defaults
/// to empty / default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AresConfig {
    /// HTTP server settings (`[server]`).
    pub server: ServerConfig,

    /// Authentication settings (`[auth]`).
    pub auth: AuthConfig,

    /// Database connection settings (`[database]`).
    pub database: DatabaseConfig,

    /// LLM providers keyed by name (`[providers.<name>]`).
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,

    /// Models keyed by name (`[models.<name>]`); each references a provider by name.
    #[serde(default)]
    pub models: HashMap<String, ModelConfig>,

    /// Tools keyed by name (`[tools.<name>]`).
    #[serde(default)]
    pub tools: HashMap<String, ToolConfig>,

    /// Agents keyed by name (`[agents.<name>]`); each references a model and tool names.
    #[serde(default)]
    pub agents: HashMap<String, AgentConfig>,

    /// Workflows keyed by name (`[workflows.<name>]`); each references agents by name.
    #[serde(default)]
    pub workflows: HashMap<String, WorkflowConfig>,

    /// Retrieval-augmented generation settings (`[rag]`).
    #[serde(default)]
    pub rag: RagConfig,

    /// Optional skill directory settings; only present with the `skills` feature.
    #[cfg(feature = "skills")]
    #[serde(default)]
    pub skills: Option<SkillsTomlConfig>,

    /// Directories for dynamically loaded configuration plus hot-reload settings (`[config]`).
    #[serde(default)]
    pub config: DynamicConfigPaths,
}
72
/// HTTP server settings (`[server]`). Every field has a serde default, so an empty table is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Bind address (default `127.0.0.1`, loopback only).
    #[serde(default = "default_host")]
    pub host: String,

    /// Listen port (default `3000`).
    #[serde(default = "default_port")]
    pub port: u16,

    /// Log level name (default `info`); interpreted by the logging setup elsewhere.
    #[serde(default = "default_log_level")]
    pub log_level: String,

    /// Allowed CORS origins (default `["http://localhost:3000"]`).
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,

    /// Sustained rate limit per second (default `100`); enforcement lives outside this file.
    #[serde(default = "default_rate_limit")]
    pub rate_limit_per_second: u32,

    /// Burst allowance for the rate limiter (default `10`); presumably token-bucket style — confirm.
    #[serde(default = "default_rate_limit_burst")]
    pub rate_limit_burst: u32,
}
103
/// Serde default for `server.host`: loopback only.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `server.port`.
fn default_port() -> u16 {
    3_000
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `server.cors_origins`: only the local development origin.
fn default_cors_origins() -> Vec<String> {
    let dev_origin = String::from("http://localhost:3000");
    vec![dev_origin]
}

/// Serde default for `server.rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100_u32
}

/// Serde default for `server.rate_limit_burst`.
fn default_rate_limit_burst() -> u32 {
    10_u32
}
127
128impl Default for ServerConfig {
129 fn default() -> Self {
130 Self {
131 host: default_host(),
132 port: default_port(),
133 log_level: default_log_level(),
134 cors_origins: default_cors_origins(),
135 rate_limit_per_second: default_rate_limit(),
136 rate_limit_burst: default_rate_limit_burst(),
137 }
138 }
139}
140
/// Authentication settings (`[auth]`).
///
/// The `*_env` fields hold the *names* of environment variables, not the secrets themselves;
/// `AresConfig::validate` requires both variables to be set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Name of the environment variable that holds the JWT signing secret (required key).
    pub jwt_secret_env: String,

    /// Access-token lifetime (default `900`); presumably seconds, i.e. 15 minutes — confirm.
    #[serde(default = "default_jwt_access_expiry")]
    pub jwt_access_expiry: i64,

    /// Refresh-token lifetime (default `604800`); presumably seconds, i.e. 7 days — confirm.
    #[serde(default = "default_jwt_refresh_expiry")]
    pub jwt_refresh_expiry: i64,

    /// Name of the environment variable that holds the API key (required key).
    pub api_key_env: String,
}
160
/// Serde default for `auth.jwt_access_expiry`: 900, i.e. 15 minutes if the unit is seconds.
fn default_jwt_access_expiry() -> i64 {
    15 * 60
}

/// Serde default for `auth.jwt_refresh_expiry`: 604800, i.e. 7 days if the unit is seconds.
fn default_jwt_refresh_expiry() -> i64 {
    7 * 24 * 60 * 60
}
168
169impl Default for AuthConfig {
170 fn default() -> Self {
171 Self {
172 jwt_secret_env: "JWT_SECRET".to_string(),
173 jwt_access_expiry: default_jwt_access_expiry(),
174 jwt_refresh_expiry: default_jwt_refresh_expiry(),
175 api_key_env: "API_KEY".to_string(),
176 }
177 }
178}
179
/// Database connection settings (`[database]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Database connection URL (default `postgres://postgres:postgres@localhost:5432/ares`).
    #[serde(default = "default_database_url")]
    pub url: String,

    /// Optional Qdrant settings; a missing `[database.qdrant]` table deserializes as `None`.
    pub qdrant: Option<QdrantConfig>,
}
192
/// Serde default for `database.url`: a local Postgres instance with stock credentials.
fn default_database_url() -> String {
    "postgres://postgres:postgres@localhost:5432/ares".to_owned()
}
196
197impl Default for DatabaseConfig {
198 fn default() -> Self {
199 Self {
200 url: default_database_url(),
201 qdrant: None,
202 }
203 }
204}
205
/// Qdrant vector database settings (`[database.qdrant]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConfig {
    /// Qdrant endpoint URL (default `http://localhost:6334`).
    #[serde(default = "default_qdrant_url")]
    pub url: String,

    /// Optional name of the environment variable holding the Qdrant API key; when set,
    /// `AresConfig::validate` requires that variable to exist.
    pub api_key_env: Option<String>,
}
216
/// Serde default for `database.qdrant.url`: a local Qdrant instance.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
220
221impl Default for QdrantConfig {
222 fn default() -> Self {
223 Self {
224 url: default_qdrant_url(),
225 api_key_env: None,
226 }
227 }
228}
229
/// Settings for one LLM provider (`[providers.<name>]`), selected by the table's `type` key.
///
/// Because of `rename_all = "lowercase"`, the accepted `type` values are `"ollama"`,
/// `"openai"`, `"llamacpp"`, and `"anthropic"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// An Ollama server.
    Ollama {
        /// Base URL of the Ollama server (default `http://localhost:11434`).
        #[serde(default = "default_ollama_url")]
        base_url: String,
        /// Model name associated with this provider by default.
        default_model: String,
    },
    /// The OpenAI HTTP API (base URL is configurable).
    OpenAI {
        /// Name of the environment variable holding the API key; checked by `validate`.
        api_key_env: String,
        /// API base URL (default `https://api.openai.com/v1`).
        #[serde(default = "default_openai_base")]
        api_base: String,
        /// Model name associated with this provider by default.
        default_model: String,
    },
    /// A local llama.cpp model.
    LlamaCpp {
        /// Path to the local model file; `validate` requires it to exist.
        model_path: String,
        /// Context size (default `4096`); presumably in tokens — confirm.
        #[serde(default = "default_n_ctx")]
        n_ctx: u32,
        /// Thread count (default `4`).
        #[serde(default = "default_n_threads")]
        n_threads: u32,
        /// Generation limit (default `512`); presumably tokens — confirm.
        #[serde(default = "default_max_tokens")]
        max_tokens: u32,
    },
    /// The Anthropic API.
    Anthropic {
        /// Name of the environment variable holding the API key; checked by `validate`.
        api_key_env: String,
        /// Model name associated with this provider by default.
        default_model: String,
    },
}
276
/// Serde default for the Ollama `base_url`.
fn default_ollama_url() -> String {
    "http://localhost:11434".into()
}

/// Serde default for the OpenAI `api_base`.
fn default_openai_base() -> String {
    "https://api.openai.com/v1".into()
}

/// Serde default for llama.cpp `n_ctx`.
fn default_n_ctx() -> u32 {
    4_096
}

/// Serde default for llama.cpp `n_threads`.
fn default_n_threads() -> u32 {
    4_u32
}

/// Serde default for llama.cpp `max_tokens`.
fn default_max_tokens() -> u32 {
    512_u32
}
296
/// A named model definition (`[models.<name>]`) bound to one provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Key of the provider in `[providers]`; `AresConfig::validate` requires it to exist.
    pub provider: String,

    /// Provider-specific model identifier.
    pub model: String,

    /// Sampling temperature (default `0.7`).
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Generation limit (default `512`); presumably tokens — confirm.
    #[serde(default = "default_model_max_tokens")]
    pub max_tokens: u32,

    /// Optional nucleus-sampling parameter; `None` when absent from the file.
    pub top_p: Option<f32>,

    /// Optional frequency penalty; `None` when absent from the file.
    pub frequency_penalty: Option<f32>,

    /// Optional presence penalty; `None` when absent from the file.
    pub presence_penalty: Option<f32>,
}
325
/// Serde default for `models.<name>.temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde default for `models.<name>.max_tokens`.
fn default_model_max_tokens() -> u32 {
    512_u32
}
333
/// Settings for one tool (`[tools.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// Whether the tool is enabled (default `true`); disabled tools are excluded by
    /// `AresConfig::enabled_tools` and `AresConfig::agent_tools`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,

    /// Timeout (default `30`); presumably seconds, per the field name.
    #[serde(default = "default_tool_timeout")]
    pub timeout_secs: u64,

    /// Every other key in the tool's table, captured via `#[serde(flatten)]` for
    /// tool-specific settings.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}
355
/// Serde helper for boolean fields that default to enabled.
fn default_true() -> bool {
    !false
}

/// Serde default for `tools.<name>.timeout_secs`.
fn default_tool_timeout() -> u64 {
    30_u64
}
363
364impl Default for ToolConfig {
365 fn default() -> Self {
366 Self {
367 enabled: true,
368 description: None,
369 timeout_secs: default_tool_timeout(),
370 extra: HashMap::new(),
371 }
372 }
373}
374
/// Settings for one agent (`[agents.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Key of the model in `[models]`; `AresConfig::validate` requires it to exist.
    pub model: String,

    /// Optional system prompt.
    #[serde(default)]
    pub system_prompt: Option<String>,

    /// Tool names; each must be a `[tools]` key or an MCP tool named `<client>_<tool>`
    /// (checked by `AresConfig::validate`).
    #[serde(default)]
    pub tools: Vec<String>,

    /// Upper bound on tool-call iterations (default `10`); enforced by the agent runtime, not here.
    #[serde(default = "default_max_tool_iterations")]
    pub max_tool_iterations: usize,

    /// Whether tool calls may run in parallel (default `false`); enforced elsewhere.
    #[serde(default)]
    pub parallel_tools: bool,

    /// Every other key in the agent's table, captured via `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}
403
/// Serde default for `agents.<name>.max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    10_usize
}
407
/// Settings for one workflow (`[workflows.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
    /// Key of the agent that starts the workflow; must exist in `[agents]`.
    pub entry_agent: String,

    /// Optional fallback agent; must exist in `[agents]` and differ from `entry_agent`.
    pub fallback_agent: Option<String>,

    /// Maximum depth (default `3`); interpreted by the workflow engine elsewhere.
    #[serde(default = "default_max_depth")]
    pub max_depth: u8,

    /// Maximum iterations (default `5`); interpreted by the workflow engine elsewhere.
    #[serde(default = "default_max_iterations")]
    pub max_iterations: u8,

    /// Whether sub-agents may run in parallel (default `false`).
    #[serde(default)]
    pub parallel_subagents: bool,
}
431
/// Serde default for `workflows.<name>.max_depth`.
fn default_max_depth() -> u8 {
    3_u8
}

/// Serde default for `workflows.<name>.max_iterations`.
fn default_max_iterations() -> u8 {
    5_u8
}
439
/// Skill directory settings (`[skills]`), compiled only with the `skills` feature.
///
/// All fields are optional; absent keys deserialize as `None`.
#[cfg(feature = "skills")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillsTomlConfig {
    /// Project-level skills directory.
    pub project_dir: Option<std::path::PathBuf>,
    /// Personal (per-user) skills directory.
    pub personal_dir: Option<std::path::PathBuf>,
    /// Additional plugin directories.
    pub plugin_dirs: Option<Vec<std::path::PathBuf>>,
}
453
/// Retrieval-augmented generation settings (`[rag]`). Every field has a serde default.
///
/// The string-valued "strategy"/"model" fields are free-form here; their accepted values
/// are decided by the RAG components that consume them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Vector store backend identifier (default `ares-vector`).
    #[serde(default = "default_vector_store")]
    pub vector_store: String,

    /// Storage path for vector data (default `./data/vectors`); presumably used by local backends.
    #[serde(default = "default_vector_path")]
    pub vector_path: String,

    /// Dense embedding model name (default `bge-small-en-v1.5`).
    #[serde(default = "default_embedding_model")]
    pub embedding_model: String,

    /// Whether sparse embeddings are enabled (default `false`).
    #[serde(default)]
    pub sparse_embeddings: bool,

    /// Sparse embedding model name (default `splade-pp-en-v1`).
    #[serde(default = "default_sparse_model")]
    pub sparse_model: String,

    /// Chunking strategy name (default `word`).
    #[serde(default = "default_chunking_strategy")]
    pub chunking_strategy: String,

    /// Target chunk size (default `200`); unit presumably follows `chunking_strategy` — confirm.
    #[serde(default = "default_chunk_size")]
    pub chunk_size: usize,

    /// Overlap between consecutive chunks (default `50`), same unit as `chunk_size`.
    #[serde(default = "default_chunk_overlap")]
    pub chunk_overlap: usize,

    /// Minimum chunk size (default `20`), same unit as `chunk_size`.
    #[serde(default = "default_min_chunk_size")]
    pub min_chunk_size: usize,

    /// Search strategy name (default `semantic`).
    #[serde(default = "default_search_strategy")]
    pub search_strategy: String,

    /// Maximum number of search results (default `10`).
    #[serde(default = "default_search_limit")]
    pub search_limit: usize,

    /// Score threshold for results (default `0.0`).
    #[serde(default)]
    pub search_threshold: f32,

    /// Optional weights for hybrid search; `None` when `[rag.hybrid_weights]` is absent.
    #[serde(default)]
    pub hybrid_weights: Option<HybridWeightsConfig>,

    /// Whether reranking is enabled (default `false`).
    #[serde(default)]
    pub rerank_enabled: bool,

    /// Reranker model name (default `bge-reranker-base`).
    #[serde(default = "default_reranker_model")]
    pub reranker_model: String,

    /// Weight given to the reranker score (default `0.6`); how it is blended is decided elsewhere.
    #[serde(default = "default_rerank_weight")]
    pub rerank_weight: f32,
}
531
/// Per-signal weights for hybrid search (`[rag.hybrid_weights]`); defaults are 0.5 / 0.3 / 0.2.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridWeightsConfig {
    /// Weight of the semantic (embedding) score (default `0.5`).
    #[serde(default = "default_semantic_weight")]
    pub semantic: f32,
    /// Weight of the BM25 score (default `0.3`).
    #[serde(default = "default_bm25_weight")]
    pub bm25: f32,
    /// Weight of the fuzzy-match score (default `0.2`).
    #[serde(default = "default_fuzzy_weight")]
    pub fuzzy: f32,
}
545
546impl Default for HybridWeightsConfig {
547 fn default() -> Self {
548 Self {
549 semantic: 0.5,
550 bm25: 0.3,
551 fuzzy: 0.2,
552 }
553 }
554}
555
/// Serde default for `rag.hybrid_weights.semantic`.
fn default_semantic_weight() -> f32 {
    0.5_f32
}

/// Serde default for `rag.hybrid_weights.bm25`.
fn default_bm25_weight() -> f32 {
    0.3_f32
}

/// Serde default for `rag.hybrid_weights.fuzzy`.
fn default_fuzzy_weight() -> f32 {
    0.2_f32
}
567
/// Serde default for `rag.vector_store`.
fn default_vector_store() -> String {
    "ares-vector".into()
}

/// Serde default for `rag.vector_path`.
fn default_vector_path() -> String {
    "./data/vectors".into()
}

/// Serde default for `rag.embedding_model`.
fn default_embedding_model() -> String {
    "bge-small-en-v1.5".into()
}

/// Serde default for `rag.sparse_model`.
fn default_sparse_model() -> String {
    "splade-pp-en-v1".into()
}

/// Serde default for `rag.chunking_strategy`.
fn default_chunking_strategy() -> String {
    "word".into()
}

/// Serde default for `rag.chunk_size`.
fn default_chunk_size() -> usize {
    200_usize
}

/// Serde default for `rag.chunk_overlap`.
fn default_chunk_overlap() -> usize {
    50_usize
}

/// Serde default for `rag.min_chunk_size`.
fn default_min_chunk_size() -> usize {
    20_usize
}

/// Serde default for `rag.search_strategy`.
fn default_search_strategy() -> String {
    "semantic".into()
}

/// Serde default for `rag.search_limit`.
fn default_search_limit() -> usize {
    10_usize
}

/// Serde default for `rag.reranker_model`.
fn default_reranker_model() -> String {
    "bge-reranker-base".into()
}

/// Serde default for `rag.rerank_weight`.
fn default_rerank_weight() -> f32 {
    0.6_f32
}
615
616impl Default for RagConfig {
617 fn default() -> Self {
618 Self {
619 vector_store: default_vector_store(),
620 vector_path: default_vector_path(),
621 embedding_model: default_embedding_model(),
622 sparse_embeddings: false,
623 sparse_model: default_sparse_model(),
624 chunking_strategy: default_chunking_strategy(),
625 chunk_size: default_chunk_size(),
626 chunk_overlap: default_chunk_overlap(),
627 min_chunk_size: default_min_chunk_size(),
628 search_strategy: default_search_strategy(),
629 search_limit: default_search_limit(),
630 search_threshold: 0.0,
631 hybrid_weights: None,
632 rerank_enabled: false,
633 reranker_model: default_reranker_model(),
634 rerank_weight: default_rerank_weight(),
635 }
636 }
637}
638
/// Locations of dynamically loaded configuration files, plus hot-reload settings (`[config]`).
///
/// Relative paths are resolved by whatever opens them; e.g. `AresConfig::mcp_client_names`
/// passes `mcps_dir` straight to `read_dir`, so it is relative to the process working directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicConfigPaths {
    /// Agent definitions directory (default `config/agents`).
    #[serde(default = "default_agents_dir")]
    pub agents_dir: std::path::PathBuf,

    /// Workflow definitions directory (default `config/workflows`).
    #[serde(default = "default_workflows_dir")]
    pub workflows_dir: std::path::PathBuf,

    /// Model definitions directory (default `config/models`).
    #[serde(default = "default_models_dir")]
    pub models_dir: std::path::PathBuf,

    /// Tool definitions directory (default `config/tools`).
    #[serde(default = "default_tools_dir")]
    pub tools_dir: std::path::PathBuf,

    /// MCP client definitions directory (default `config/mcps`).
    #[serde(default = "default_mcps_dir")]
    pub mcps_dir: std::path::PathBuf,

    /// Whether hot reload is enabled (default `true`).
    /// NOTE(review): `AresConfigManager` in this file does not read this flag — confirm
    /// whether other code does.
    #[serde(default = "default_hot_reload")]
    pub hot_reload: bool,

    /// Watch interval in milliseconds (default `1000`).
    /// NOTE(review): the watcher in this file uses its own fixed debounce instead.
    #[serde(default = "default_watch_interval")]
    pub watch_interval_ms: u64,
}
676
/// Serde default for `config.agents_dir`.
fn default_agents_dir() -> std::path::PathBuf {
    PathBuf::from("config/agents")
}

/// Serde default for `config.workflows_dir`.
fn default_workflows_dir() -> std::path::PathBuf {
    PathBuf::from("config/workflows")
}

/// Serde default for `config.models_dir`.
fn default_models_dir() -> std::path::PathBuf {
    PathBuf::from("config/models")
}

/// Serde default for `config.tools_dir`.
fn default_tools_dir() -> std::path::PathBuf {
    PathBuf::from("config/tools")
}

/// Serde default for `config.mcps_dir`.
fn default_mcps_dir() -> std::path::PathBuf {
    PathBuf::from("config/mcps")
}

/// Serde default for `config.hot_reload`: enabled.
fn default_hot_reload() -> bool {
    !false
}

/// Serde default for `config.watch_interval_ms`: one second.
fn default_watch_interval() -> u64 {
    1_000
}
704
705impl Default for DynamicConfigPaths {
706 fn default() -> Self {
707 Self {
708 agents_dir: default_agents_dir(),
709 workflows_dir: default_workflows_dir(),
710 models_dir: default_models_dir(),
711 tools_dir: default_tools_dir(),
712 mcps_dir: default_mcps_dir(),
713 hot_reload: default_hot_reload(),
714 watch_interval_ms: default_watch_interval(),
715 }
716 }
717}
718
/// A non-fatal configuration issue reported by `AresConfig::validate_with_warnings`.
#[derive(Debug, Clone)]
pub struct ConfigWarning {
    /// Category of the issue.
    pub kind: ConfigWarningKind,

    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
}
730
/// Category of a `ConfigWarning`: something defined in the config but never referenced.
///
/// Fieldless, so it is `Copy` and can be compared (`Eq`) or collected into sets (`Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigWarningKind {
    /// A provider that no model references.
    UnusedProvider,

    /// A model that no agent references.
    UnusedModel,

    /// A tool that no agent lists.
    UnusedTool,

    /// An agent that no workflow references (the `orchestrator` and `router` agents are exempt).
    UnusedAgent,
}
746
747impl std::fmt::Display for ConfigWarning {
748 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
749 write!(f, "{}", self.message)
750 }
751}
752
/// Errors produced while loading, validating, or watching the configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file does not exist at the given path.
    #[error("Configuration file not found: {0}")]
    FileNotFound(PathBuf),

    /// An I/O error while reading a file (or resolving the current directory).
    #[error("Failed to read configuration file: {0}")]
    ReadError(#[from] std::io::Error),

    /// The file is not valid TOML, or does not match the configuration schema.
    #[error("Failed to parse TOML: {0}")]
    ParseError(#[from] toml::de::Error),

    /// A free-form validation failure described by the message.
    #[error("Validation error: {0}")]
    ValidationError(String),

    /// The named environment variable is unset (or not valid Unicode).
    #[error("Environment variable '{0}' referenced in config is not set")]
    MissingEnvVar(String),

    /// `(provider, model)`: a model references a provider that is not defined.
    #[error("Provider '{0}' referenced by model '{1}' does not exist")]
    MissingProvider(String, String),

    /// `(model, agent)`: an agent references a model that is not defined.
    #[error("Model '{0}' referenced by agent '{1}' does not exist")]
    MissingModel(String, String),

    /// `(agent, workflow)`: a workflow references an agent that is not defined.
    #[error("Agent '{0}' referenced by workflow '{1}' does not exist")]
    MissingAgent(String, String),

    /// `(tool, agent)`: an agent lists a tool that is neither configured nor provided by MCP.
    #[error("Tool '{0}' referenced by agent '{1}' does not exist")]
    MissingTool(String, String),

    /// A reference cycle, described by the message.
    #[error("Circular reference detected: {0}")]
    CircularReference(String),

    /// The file watcher could not be created or attached.
    #[error("Watch error: {0}")]
    WatchError(#[from] notify::Error),
}
800
801impl AresConfig {
802 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
809 let path = path.as_ref();
810
811 if !path.exists() {
812 return Err(ConfigError::FileNotFound(path.to_path_buf()));
813 }
814
815 let content = fs::read_to_string(path)?;
816 let config: AresConfig = toml::from_str(&content)?;
817
818 config.validate()?;
820
821 Ok(config)
822 }
823
824 pub fn load_unchecked<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
830 let path = path.as_ref();
831
832 if !path.exists() {
833 return Err(ConfigError::FileNotFound(path.to_path_buf()));
834 }
835
836 let content = fs::read_to_string(path)?;
837 let config: AresConfig = toml::from_str(&content)?;
838
839 Ok(config)
840 }
841
842 pub fn validate(&self) -> Result<(), ConfigError> {
844 self.validate_env_var(&self.auth.jwt_secret_env)?;
846 self.validate_env_var(&self.auth.api_key_env)?;
847
848 if let Some(ref qdrant) = self.database.qdrant {
850 if let Some(ref env) = qdrant.api_key_env {
851 self.validate_env_var(env)?;
852 }
853 }
854
855 for (name, provider) in &self.providers {
857 match provider {
858 ProviderConfig::OpenAI { api_key_env, .. } => {
859 self.validate_env_var(api_key_env)?;
860 }
861 ProviderConfig::Anthropic { api_key_env, .. } => {
862 self.validate_env_var(api_key_env)?;
863 }
864 ProviderConfig::LlamaCpp { model_path, .. } => {
865 if !Path::new(model_path).exists() {
867 return Err(ConfigError::ValidationError(format!(
868 "LlamaCpp model path does not exist: {} (provider: {})",
869 model_path, name
870 )));
871 }
872 }
873 ProviderConfig::Ollama { .. } => {
874 }
876 }
877 }
878
879 for (model_name, model_config) in &self.models {
881 if !self.providers.contains_key(&model_config.provider) {
882 return Err(ConfigError::MissingProvider(
883 model_config.provider.clone(),
884 model_name.clone(),
885 ));
886 }
887 }
888
889 for (agent_name, agent_config) in &self.agents {
891 if !self.models.contains_key(&agent_config.model) {
892 return Err(ConfigError::MissingModel(
893 agent_config.model.clone(),
894 agent_name.clone(),
895 ));
896 }
897
898 for tool_name in &agent_config.tools {
899 let is_known_tool = self.tools.contains_key(tool_name);
902 let is_mcp_tool = tool_name.contains('_') && {
903 let mcp_names = self.mcp_client_names();
905 mcp_names.iter().any(|mcp_name| tool_name.starts_with(&format!("{}_", mcp_name)))
906 };
907 if !is_known_tool && !is_mcp_tool {
908 return Err(ConfigError::MissingTool(
909 tool_name.clone(),
910 agent_name.clone(),
911 ));
912 }
913 }
914 }
915
916 for (workflow_name, workflow_config) in &self.workflows {
918 if !self.agents.contains_key(&workflow_config.entry_agent) {
919 return Err(ConfigError::MissingAgent(
920 workflow_config.entry_agent.clone(),
921 workflow_name.clone(),
922 ));
923 }
924
925 if let Some(ref fallback) = workflow_config.fallback_agent {
926 if !self.agents.contains_key(fallback) {
927 return Err(ConfigError::MissingAgent(
928 fallback.clone(),
929 workflow_name.clone(),
930 ));
931 }
932 }
933 }
934
935 self.detect_circular_references()?;
937
938 Ok(())
939 }
940
941 fn detect_circular_references(&self) -> Result<(), ConfigError> {
946 use std::collections::HashSet;
947
948 for (workflow_name, workflow_config) in &self.workflows {
949 let mut visited = HashSet::new();
950 let mut current = Some(workflow_config.entry_agent.as_str());
951
952 while let Some(agent_name) = current {
953 if visited.contains(agent_name) {
954 return Err(ConfigError::CircularReference(format!(
955 "Circular reference detected in workflow '{}': agent '{}' appears multiple times in the chain",
956 workflow_name, agent_name
957 )));
958 }
959 visited.insert(agent_name);
960
961 current = None;
964
965 if let Some(ref fallback) = workflow_config.fallback_agent {
967 if fallback == &workflow_config.entry_agent {
968 return Err(ConfigError::CircularReference(format!(
969 "Workflow '{}' has entry_agent '{}' that equals fallback_agent",
970 workflow_name, workflow_config.entry_agent
971 )));
972 }
973 }
974 }
975 }
976
977 Ok(())
978 }
979
980 pub fn validate_with_warnings(&self) -> Result<Vec<ConfigWarning>, ConfigError> {
984 self.validate()?;
986
987 let mut warnings = Vec::new();
989
990 warnings.extend(self.check_unused_providers());
992
993 warnings.extend(self.check_unused_models());
995
996 warnings.extend(self.check_unused_tools());
998
999 warnings.extend(self.check_unused_agents());
1001
1002 Ok(warnings)
1003 }
1004
1005 fn check_unused_providers(&self) -> Vec<ConfigWarning> {
1007 use std::collections::HashSet;
1008
1009 let referenced: HashSet<_> = self.models.values().map(|m| m.provider.as_str()).collect();
1010
1011 self.providers
1012 .keys()
1013 .filter(|name| !referenced.contains(name.as_str()))
1014 .map(|name| ConfigWarning {
1015 kind: ConfigWarningKind::UnusedProvider,
1016 message: format!(
1017 "Provider '{}' is defined but not referenced by any model",
1018 name
1019 ),
1020 })
1021 .collect()
1022 }
1023
1024 fn check_unused_models(&self) -> Vec<ConfigWarning> {
1026 use std::collections::HashSet;
1027
1028 let referenced: HashSet<_> = self.agents.values().map(|a| a.model.as_str()).collect();
1029
1030 self.models
1031 .keys()
1032 .filter(|name| !referenced.contains(name.as_str()))
1033 .map(|name| ConfigWarning {
1034 kind: ConfigWarningKind::UnusedModel,
1035 message: format!(
1036 "Model '{}' is defined but not referenced by any agent",
1037 name
1038 ),
1039 })
1040 .collect()
1041 }
1042
1043 fn check_unused_tools(&self) -> Vec<ConfigWarning> {
1045 use std::collections::HashSet;
1046
1047 let referenced: HashSet<_> = self
1048 .agents
1049 .values()
1050 .flat_map(|a| a.tools.iter().map(|t| t.as_str()))
1051 .collect();
1052
1053 self.tools
1054 .keys()
1055 .filter(|name| !referenced.contains(name.as_str()))
1056 .map(|name| ConfigWarning {
1057 kind: ConfigWarningKind::UnusedTool,
1058 message: format!("Tool '{}' is defined but not referenced by any agent", name),
1059 })
1060 .collect()
1061 }
1062
1063 fn check_unused_agents(&self) -> Vec<ConfigWarning> {
1065 use std::collections::HashSet;
1066
1067 let referenced: HashSet<_> = self
1068 .workflows
1069 .values()
1070 .flat_map(|w| {
1071 let mut refs = vec![w.entry_agent.as_str()];
1072 if let Some(ref fallback) = w.fallback_agent {
1073 refs.push(fallback.as_str());
1074 }
1075 refs
1076 })
1077 .collect();
1078
1079 let system_agents: HashSet<&str> = ["orchestrator", "router"].into_iter().collect();
1081
1082 self.agents
1083 .keys()
1084 .filter(|name| {
1085 !referenced.contains(name.as_str()) && !system_agents.contains(name.as_str())
1086 })
1087 .map(|name| ConfigWarning {
1088 kind: ConfigWarningKind::UnusedAgent,
1089 message: format!(
1090 "Agent '{}' is defined but not referenced by any workflow",
1091 name
1092 ),
1093 })
1094 .collect()
1095 }
1096
1097 fn validate_env_var(&self, name: &str) -> Result<(), ConfigError> {
1098 std::env::var(name).map_err(|_| ConfigError::MissingEnvVar(name.to_string()))?;
1099 Ok(())
1100 }
1101
1102 pub fn resolve_env(&self, env_name: &str) -> Option<String> {
1104 std::env::var(env_name).ok()
1105 }
1106
1107 const JWT_SECRET_MIN_LENGTH: usize = 32;
1109
1110 pub fn mcp_client_names(&self) -> Vec<String> {
1119 let path = &self.config.mcps_dir;
1120 if !path.exists() { return vec![]; }
1121 std::fs::read_dir(path)
1122 .ok()
1123 .map(|entries| {
1124 entries.filter_map(|e| {
1125 let e = e.ok()?;
1126 let p = e.path();
1127 if p.extension()?.to_str()? == "toon" {
1128 let content = std::fs::read_to_string(&p).ok()?;
1130 let val: toml::Value = toml::from_str(&content).ok()?;
1131 val.get("name")?.as_str().map(String::from)
1132 } else { None }
1133 }).collect()
1134 })
1135 .unwrap_or_default()
1136 }
1137
1138 pub fn jwt_secret(&self) -> Result<String, ConfigError> {
1139 let secret = self
1140 .resolve_env(&self.auth.jwt_secret_env)
1141 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.jwt_secret_env.clone()))?;
1142
1143 if secret.len() < Self::JWT_SECRET_MIN_LENGTH {
1144 return Err(ConfigError::ValidationError(format!(
1145 "JWT_SECRET must be at least {} characters for security (current: {} chars). \
1146 Use a cryptographically random string, e.g.: openssl rand -base64 32",
1147 Self::JWT_SECRET_MIN_LENGTH,
1148 secret.len()
1149 )));
1150 }
1151
1152 Ok(secret)
1153 }
1154
1155 pub fn api_key(&self) -> Result<String, ConfigError> {
1157 self.resolve_env(&self.auth.api_key_env)
1158 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.api_key_env.clone()))
1159 }
1160
1161 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
1163 self.providers.get(name)
1164 }
1165
1166 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
1168 self.models.get(name)
1169 }
1170
1171 pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
1173 self.agents.get(name)
1174 }
1175
1176 pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
1178 self.tools.get(name)
1179 }
1180
1181 pub fn get_workflow(&self, name: &str) -> Option<&WorkflowConfig> {
1183 self.workflows.get(name)
1184 }
1185
1186 pub fn enabled_tools(&self) -> Vec<&str> {
1188 self.tools
1189 .iter()
1190 .filter(|(_, config)| config.enabled)
1191 .map(|(name, _)| name.as_str())
1192 .collect()
1193 }
1194
1195 pub fn agent_tools(&self, agent_name: &str) -> Vec<&str> {
1197 self.get_agent(agent_name)
1198 .map(|agent| {
1199 agent
1200 .tools
1201 .iter()
1202 .filter(|t| self.get_tool(t).map(|tc| tc.enabled).unwrap_or(false))
1203 .map(|s| s.as_str())
1204 .collect()
1205 })
1206 .unwrap_or_default()
1207 }
1208}
1209
/// Holder for the current [`AresConfig`], with optional file-watching hot reload.
///
/// [`AresConfigManager::config`] hands out `Arc` snapshots. A reload replaces the whole
/// configuration with a single `ArcSwap::store`, so an existing snapshot never changes.
pub struct AresConfigManager {
    /// The current configuration; the `Arc` is shared with the hot-reload task.
    config: Arc<ArcSwap<AresConfig>>,
    /// Configuration file path. `new` makes it absolute; `from_config` stores a placeholder.
    config_path: PathBuf,
    /// Active filesystem watcher, if any; dropping it stops change notifications.
    watcher: RwLock<Option<RecommendedWatcher>>,
    /// Sender for the reload-notification channel, set by `start_watching`.
    reload_tx: Option<mpsc::UnboundedSender<()>>,
}
1219
1220impl AresConfigManager {
1221 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
1227 let path = path.as_ref();
1229 let path = if path.is_absolute() {
1230 path.to_path_buf()
1231 } else {
1232 std::env::current_dir()
1233 .map_err(ConfigError::ReadError)?
1234 .join(path)
1235 };
1236
1237 let config = AresConfig::load(&path)?;
1238
1239 Ok(Self {
1240 config: Arc::new(ArcSwap::from_pointee(config)),
1241 config_path: path,
1242 watcher: RwLock::new(None),
1243 reload_tx: None,
1244 })
1245 }
1246
1247 pub fn config(&self) -> Arc<AresConfig> {
1249 self.config.load_full()
1250 }
1251
1252 pub fn reload(&self) -> Result<(), ConfigError> {
1254 info!("Reloading configuration from {:?}", self.config_path);
1255
1256 let new_config = AresConfig::load(&self.config_path)?;
1257 self.config.store(Arc::new(new_config));
1258
1259 info!("Configuration reloaded successfully");
1260 Ok(())
1261 }
1262
1263 pub fn start_watching(&mut self) -> Result<(), ConfigError> {
1265 let (tx, mut rx) = mpsc::unbounded_channel::<()>();
1266 self.reload_tx = Some(tx.clone());
1267
1268 let config_path = self.config_path.clone();
1269 let config_arc = Arc::clone(&self.config);
1270
1271 let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
1273 match res {
1274 Ok(event) => {
1275 if event.kind.is_modify() || event.kind.is_create() {
1276 let _ = tx.send(());
1278 }
1279 }
1280 Err(e) => {
1281 error!("Config watcher error: {:?}", e);
1282 }
1283 }
1284 })?;
1285
1286 if let Some(parent) = self.config_path.parent() {
1288 watcher.watch(parent, RecursiveMode::NonRecursive)?;
1289 }
1290
1291 *self.watcher.write() = Some(watcher);
1292
1293 let config_path_clone = config_path.clone();
1295 tokio::spawn(async move {
1296 let mut last_reload = std::time::Instant::now();
1297 let debounce_duration = Duration::from_millis(500);
1298
1299 while rx.recv().await.is_some() {
1300 if last_reload.elapsed() < debounce_duration {
1302 continue;
1303 }
1304
1305 tokio::time::sleep(Duration::from_millis(100)).await;
1307
1308 match AresConfig::load(&config_path_clone) {
1309 Ok(new_config) => {
1310 config_arc.store(Arc::new(new_config));
1311 info!("Configuration hot-reloaded successfully");
1312 last_reload = std::time::Instant::now();
1313 }
1314 Err(e) => {
1315 warn!(
1316 "Failed to hot-reload config: {}. Keeping previous config.",
1317 e
1318 );
1319 }
1320 }
1321 }
1322 });
1323
1324 info!("Configuration hot-reload watcher started");
1325 Ok(())
1326 }
1327
1328 pub fn stop_watching(&self) {
1330 *self.watcher.write() = None;
1331 info!("Configuration hot-reload watcher stopped");
1332 }
1333}
1334
1335impl Clone for AresConfigManager {
1336 fn clone(&self) -> Self {
1337 Self {
1338 config: Arc::clone(&self.config),
1339 config_path: self.config_path.clone(),
1340 watcher: RwLock::new(None), reload_tx: self.reload_tx.clone(),
1342 }
1343 }
1344}
1345
1346impl AresConfigManager {
1347 pub fn from_config(config: AresConfig) -> Self {
1350 Self {
1351 config: Arc::new(ArcSwap::from_pointee(config)),
1352 config_path: PathBuf::from("test-config.toml"),
1353 watcher: RwLock::new(None),
1354 reload_tx: None,
1355 }
1356 }
1357}
1358
1359#[cfg(test)]
1360mod tests {
1361 use super::*;
1362
1363 fn create_test_config() -> String {
1364 r#"
1365[server]
1366host = "127.0.0.1"
1367port = 3000
1368log_level = "debug"
1369
1370[auth]
1371jwt_secret_env = "TEST_JWT_SECRET"
1372jwt_access_expiry = 900
1373jwt_refresh_expiry = 604800
1374api_key_env = "TEST_API_KEY"
1375
1376[database]
1377url = "./data/test.db"
1378
1379[providers.ollama-local]
1380type = "ollama"
1381base_url = "http://localhost:11434"
1382default_model = "ministral-3:3b"
1383
1384[models.default]
1385provider = "ollama-local"
1386model = "ministral-3:3b"
1387temperature = 0.7
1388max_tokens = 512
1389
1390[tools.calculator]
1391enabled = true
1392description = "Basic calculator"
1393timeout_secs = 10
1394
1395[agents.router]
1396model = "default"
1397tools = []
1398max_tool_iterations = 5
1399
1400[workflows.default]
1401entry_agent = "router"
1402max_depth = 3
1403max_iterations = 5
1404"#
1405 .to_string()
1406 }
1407
1408 #[test]
1409 fn test_parse_config() {
1410 unsafe {
1413 std::env::set_var(
1414 "TEST_JWT_SECRET",
1415 "test-secret-at-least-32-characters-long-at-least-32-characters-long",
1416 );
1417 std::env::set_var("TEST_API_KEY", "test-api-key");
1418 }
1419
1420 let content = create_test_config();
1421 let config: AresConfig = toml::from_str(&content).expect("Failed to parse config");
1422
1423 assert_eq!(config.server.host, "127.0.0.1");
1424 assert_eq!(config.server.port, 3000);
1425 assert!(config.providers.contains_key("ollama-local"));
1426 assert!(config.models.contains_key("default"));
1427 assert!(config.agents.contains_key("router"));
1428 }
1429
1430 #[test]
1431 fn test_validation_missing_provider() {
1432 unsafe {
1434 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1435 std::env::set_var("TEST_API_KEY", "test-key");
1436 }
1437
1438 let content = r#"
1439[server]
1440[auth]
1441jwt_secret_env = "TEST_JWT_SECRET"
1442api_key_env = "TEST_API_KEY"
1443[database]
1444[models.test]
1445provider = "nonexistent"
1446model = "test"
1447"#;
1448
1449 let config: AresConfig = toml::from_str(content).unwrap();
1450 let result = config.validate();
1451
1452 assert!(matches!(result, Err(ConfigError::MissingProvider(_, _))));
1453 }
1454
1455 #[test]
1456 fn test_validation_missing_model() {
1457 unsafe {
1459 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1460 std::env::set_var("TEST_API_KEY", "test-key");
1461 }
1462
1463 let content = r#"
1464[server]
1465[auth]
1466jwt_secret_env = "TEST_JWT_SECRET"
1467api_key_env = "TEST_API_KEY"
1468[database]
1469[providers.test]
1470type = "ollama"
1471default_model = "ministral-3:3b"
1472[agents.test]
1473model = "nonexistent"
1474"#;
1475
1476 let config: AresConfig = toml::from_str(content).unwrap();
1477 let result = config.validate();
1478
1479 assert!(matches!(result, Err(ConfigError::MissingModel(_, _))));
1480 }
1481
1482 #[test]
1483 fn test_validation_missing_tool() {
1484 unsafe {
1486 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1487 std::env::set_var("TEST_API_KEY", "test-key");
1488 }
1489
1490 let content = r#"
1491[server]
1492[auth]
1493jwt_secret_env = "TEST_JWT_SECRET"
1494api_key_env = "TEST_API_KEY"
1495[database]
1496[providers.test]
1497type = "ollama"
1498default_model = "ministral-3:3b"
1499[models.default]
1500provider = "test"
1501model = "ministral-3:3b"
1502[agents.test]
1503model = "default"
1504tools = ["nonexistent_tool"]
1505"#;
1506
1507 let config: AresConfig = toml::from_str(content).unwrap();
1508 let result = config.validate();
1509
1510 assert!(matches!(result, Err(ConfigError::MissingTool(_, _))));
1511 }
1512
1513 #[test]
1514 fn test_validation_missing_workflow_agent() {
1515 unsafe {
1517 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1518 std::env::set_var("TEST_API_KEY", "test-key");
1519 }
1520
1521 let content = r#"
1522[server]
1523[auth]
1524jwt_secret_env = "TEST_JWT_SECRET"
1525api_key_env = "TEST_API_KEY"
1526[database]
1527[workflows.test]
1528entry_agent = "nonexistent_agent"
1529"#;
1530
1531 let config: AresConfig = toml::from_str(content).unwrap();
1532 let result = config.validate();
1533
1534 assert!(matches!(result, Err(ConfigError::MissingAgent(_, _))));
1535 }
1536
1537 #[test]
1538 fn test_get_provider() {
1539 let content = create_test_config();
1540 let config: AresConfig = toml::from_str(&content).unwrap();
1541
1542 assert!(config.get_provider("ollama-local").is_some());
1543 assert!(config.get_provider("nonexistent").is_none());
1544 }
1545
1546 #[test]
1547 fn test_get_model() {
1548 let content = create_test_config();
1549 let config: AresConfig = toml::from_str(&content).unwrap();
1550
1551 assert!(config.get_model("default").is_some());
1552 assert!(config.get_model("nonexistent").is_none());
1553 }
1554
1555 #[test]
1556 fn test_get_agent() {
1557 let content = create_test_config();
1558 let config: AresConfig = toml::from_str(&content).unwrap();
1559
1560 assert!(config.get_agent("router").is_some());
1561 assert!(config.get_agent("nonexistent").is_none());
1562 }
1563
1564 #[test]
1565 fn test_get_tool() {
1566 let content = create_test_config();
1567 let config: AresConfig = toml::from_str(&content).unwrap();
1568
1569 assert!(config.get_tool("calculator").is_some());
1570 assert!(config.get_tool("nonexistent").is_none());
1571 }
1572
1573 #[test]
1574 fn test_enabled_tools() {
1575 let content = r#"
1576[server]
1577[auth]
1578jwt_secret_env = "TEST_JWT_SECRET"
1579api_key_env = "TEST_API_KEY"
1580[database]
1581[tools.enabled_tool]
1582enabled = true
1583[tools.disabled_tool]
1584enabled = false
1585"#;
1586
1587 let config: AresConfig = toml::from_str(content).unwrap();
1588 let enabled = config.enabled_tools();
1589
1590 assert!(enabled.contains(&"enabled_tool"));
1591 assert!(!enabled.contains(&"disabled_tool"));
1592 }
1593
1594 #[test]
1595 fn test_defaults() {
1596 let content = r#"
1597[server]
1598[auth]
1599jwt_secret_env = "TEST_JWT_SECRET"
1600api_key_env = "TEST_API_KEY"
1601[database]
1602"#;
1603
1604 let config: AresConfig = toml::from_str(content).unwrap();
1605
1606 assert_eq!(config.server.host, "127.0.0.1");
1608 assert_eq!(config.server.port, 3000);
1609 assert_eq!(config.server.log_level, "info");
1610
1611 assert_eq!(config.auth.jwt_access_expiry, 900);
1613 assert_eq!(config.auth.jwt_refresh_expiry, 604800);
1614
1615 assert_eq!(config.database.url, "postgres://postgres:postgres@localhost:5432/ares");
1617
1618 assert_eq!(config.rag.embedding_model, "bge-small-en-v1.5");
1620 assert_eq!(config.rag.chunk_size, 200);
1621 assert_eq!(config.rag.chunk_overlap, 50);
1622 assert_eq!(config.rag.vector_store, "ares-vector");
1623 assert_eq!(config.rag.search_strategy, "semantic");
1624 }
1625
1626 #[test]
1627 fn test_config_manager_from_config() {
1628 let content = create_test_config();
1629 let config: AresConfig = toml::from_str(&content).unwrap();
1630
1631 let manager = AresConfigManager::from_config(config.clone());
1632 let loaded = manager.config();
1633
1634 assert_eq!(loaded.server.host, config.server.host);
1635 assert_eq!(loaded.server.port, config.server.port);
1636 }
1637
1638 #[test]
1639 fn test_circular_reference_detection() {
1640 unsafe {
1642 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1643 std::env::set_var("TEST_API_KEY", "test-key");
1644 }
1645
1646 let content = r#"
1647[server]
1648[auth]
1649jwt_secret_env = "TEST_JWT_SECRET"
1650api_key_env = "TEST_API_KEY"
1651[database]
1652[providers.test]
1653type = "ollama"
1654default_model = "ministral-3:3b"
1655[models.default]
1656provider = "test"
1657model = "ministral-3:3b"
1658[agents.agent_a]
1659model = "default"
1660[workflows.circular]
1661entry_agent = "agent_a"
1662fallback_agent = "agent_a"
1663"#;
1664
1665 let config: AresConfig = toml::from_str(content).unwrap();
1666 let result = config.validate();
1667
1668 assert!(matches!(result, Err(ConfigError::CircularReference(_))));
1669 }
1670
1671 #[test]
1672 fn test_unused_provider_warning() {
1673 unsafe {
1675 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1676 std::env::set_var("TEST_API_KEY", "test-key");
1677 }
1678
1679 let content = r#"
1680[server]
1681[auth]
1682jwt_secret_env = "TEST_JWT_SECRET"
1683api_key_env = "TEST_API_KEY"
1684[database]
1685[providers.used]
1686type = "ollama"
1687default_model = "ministral-3:3b"
1688[providers.unused]
1689type = "ollama"
1690default_model = "ministral-3:3b"
1691[models.default]
1692provider = "used"
1693model = "ministral-3:3b"
1694[agents.router]
1695model = "default"
1696"#;
1697
1698 let config: AresConfig = toml::from_str(content).unwrap();
1699 let warnings = config.validate_with_warnings().unwrap();
1700
1701 assert!(warnings
1702 .iter()
1703 .any(|w| w.kind == ConfigWarningKind::UnusedProvider && w.message.contains("unused")));
1704 }
1705
1706 #[test]
1707 fn test_unused_model_warning() {
1708 unsafe {
1710 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1711 std::env::set_var("TEST_API_KEY", "test-key");
1712 }
1713
1714 let content = r#"
1715[server]
1716[auth]
1717jwt_secret_env = "TEST_JWT_SECRET"
1718api_key_env = "TEST_API_KEY"
1719[database]
1720[providers.test]
1721type = "ollama"
1722default_model = "ministral-3:3b"
1723[models.used]
1724provider = "test"
1725model = "ministral-3:3b"
1726[models.unused]
1727provider = "test"
1728model = "other"
1729[agents.router]
1730model = "used"
1731"#;
1732
1733 let config: AresConfig = toml::from_str(content).unwrap();
1734 let warnings = config.validate_with_warnings().unwrap();
1735
1736 assert!(warnings
1737 .iter()
1738 .any(|w| w.kind == ConfigWarningKind::UnusedModel && w.message.contains("unused")));
1739 }
1740
1741 #[test]
1742 fn test_unused_tool_warning() {
1743 unsafe {
1745 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1746 std::env::set_var("TEST_API_KEY", "test-key");
1747 }
1748
1749 let content = r#"
1750[server]
1751[auth]
1752jwt_secret_env = "TEST_JWT_SECRET"
1753api_key_env = "TEST_API_KEY"
1754[database]
1755[providers.test]
1756type = "ollama"
1757default_model = "ministral-3:3b"
1758[models.default]
1759provider = "test"
1760model = "ministral-3:3b"
1761[tools.used_tool]
1762enabled = true
1763[tools.unused_tool]
1764enabled = true
1765[agents.router]
1766model = "default"
1767tools = ["used_tool"]
1768"#;
1769
1770 let config: AresConfig = toml::from_str(content).unwrap();
1771 let warnings = config.validate_with_warnings().unwrap();
1772
1773 assert!(warnings
1774 .iter()
1775 .any(|w| w.kind == ConfigWarningKind::UnusedTool && w.message.contains("unused_tool")));
1776 }
1777
1778 #[test]
1779 fn test_unused_agent_warning() {
1780 unsafe {
1782 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1783 std::env::set_var("TEST_API_KEY", "test-key");
1784 }
1785
1786 let content = r#"
1787[server]
1788[auth]
1789jwt_secret_env = "TEST_JWT_SECRET"
1790api_key_env = "TEST_API_KEY"
1791[database]
1792[providers.test]
1793type = "ollama"
1794default_model = "ministral-3:3b"
1795[models.default]
1796provider = "test"
1797model = "ministral-3:3b"
1798[agents.router]
1799model = "default"
1800[agents.orphaned]
1801model = "default"
1802[workflows.test_flow]
1803entry_agent = "router"
1804"#;
1805
1806 let config: AresConfig = toml::from_str(content).unwrap();
1807 let warnings = config.validate_with_warnings().unwrap();
1808
1809 assert!(warnings
1810 .iter()
1811 .any(|w| w.kind == ConfigWarningKind::UnusedAgent && w.message.contains("orphaned")));
1812 }
1813
1814 #[test]
1815 fn test_no_warnings_for_fully_connected_config() {
1816 unsafe {
1818 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1819 std::env::set_var("TEST_API_KEY", "test-key");
1820 }
1821
1822 let content = r#"
1823[server]
1824[auth]
1825jwt_secret_env = "TEST_JWT_SECRET"
1826api_key_env = "TEST_API_KEY"
1827[database]
1828[providers.test]
1829type = "ollama"
1830default_model = "ministral-3:3b"
1831[models.default]
1832provider = "test"
1833model = "ministral-3:3b"
1834[tools.calc]
1835enabled = true
1836[agents.router]
1837model = "default"
1838tools = ["calc"]
1839[workflows.main]
1840entry_agent = "router"
1841"#;
1842
1843 let config: AresConfig = toml::from_str(content).unwrap();
1844 let warnings = config.validate_with_warnings().unwrap();
1845
1846 assert!(
1847 warnings.is_empty(),
1848 "Expected no warnings but got: {:?}",
1849 warnings
1850 );
1851 }
1852}