1use arc_swap::ArcSwap;
12use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::mpsc;
21use tracing::{error, info, warn};
22
/// Root of the ARES configuration, as deserialized from a TOML document.
///
/// The `server`, `auth` and `database` tables carry no `#[serde(default)]`
/// and therefore must be present; every other section falls back to its
/// `Default` value (an empty map for the named-entry tables) when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AresConfig {
    /// HTTP server settings (`[server]`).
    pub server: ServerConfig,

    /// Authentication settings (`[auth]`).
    pub auth: AuthConfig,

    /// Database settings (`[database]`).
    pub database: DatabaseConfig,

    /// LLM providers keyed by provider name (`[providers.<name>]`).
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,

    /// Model definitions keyed by model name; each names a provider
    /// (`[models.<name>]`).
    #[serde(default)]
    pub models: HashMap<String, ModelConfig>,

    /// Tool definitions keyed by tool name (`[tools.<name>]`).
    #[serde(default)]
    pub tools: HashMap<String, ToolConfig>,

    /// Agent definitions keyed by agent name; each names a model and a list
    /// of tools (`[agents.<name>]`).
    #[serde(default)]
    pub agents: HashMap<String, AgentConfig>,

    /// Workflow definitions keyed by workflow name; each names an entry agent
    /// and optionally a fallback agent (`[workflows.<name>]`).
    #[serde(default)]
    pub workflows: HashMap<String, WorkflowConfig>,

    /// Retrieval-augmented-generation settings (`[rag]`).
    #[serde(default)]
    pub rag: RagConfig,

    /// Directories and hot-reload settings for dynamic configuration
    /// (`[config]`).
    #[serde(default)]
    pub config: DynamicConfigPaths,
}
67
/// Settings of the `[server]` table.
///
/// Every field is optional in the TOML source; a missing key is filled in by
/// the matching `default_*` function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host/address string for the server; defaults to `"127.0.0.1"`.
    #[serde(default = "default_host")]
    pub host: String,

    /// TCP port; defaults to `3000`.
    #[serde(default = "default_port")]
    pub port: u16,

    /// Log level name; defaults to `"info"`.
    #[serde(default = "default_log_level")]
    pub log_level: String,

    /// CORS origin entries; defaults to the single entry `"*"`.
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,

    /// Rate limit per second; defaults to `100`.
    #[serde(default = "default_rate_limit")]
    pub rate_limit_per_second: u32,

    /// Rate-limit burst size; defaults to `10`.
    #[serde(default = "default_rate_limit_burst")]
    pub rate_limit_burst: u32,
}
98
/// Serde default for `ServerConfig::host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `ServerConfig::port`.
fn default_port() -> u16 {
    3000_u16
}

/// Serde default for `ServerConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `ServerConfig::cors_origins`: a single `"*"` entry.
fn default_cors_origins() -> Vec<String> {
    vec![String::from("*")]
}

/// Serde default for `ServerConfig::rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100_u32
}

/// Serde default for `ServerConfig::rate_limit_burst`.
fn default_rate_limit_burst() -> u32 {
    10_u32
}
122
123impl Default for ServerConfig {
124 fn default() -> Self {
125 Self {
126 host: default_host(),
127 port: default_port(),
128 log_level: default_log_level(),
129 cors_origins: default_cors_origins(),
130 rate_limit_per_second: default_rate_limit(),
131 rate_limit_burst: default_rate_limit_burst(),
132 }
133 }
134}
135
/// Settings of the `[auth]` table.
///
/// Secrets are never stored in the file itself: the `*_env` fields hold the
/// *names* of environment variables that are looked up at runtime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Name of the environment variable holding the JWT signing secret
    /// (required key).
    pub jwt_secret_env: String,

    /// Access-token expiry; defaults to `900`.
    // NOTE(review): presumably seconds (15 minutes) — unit is not shown here.
    #[serde(default = "default_jwt_access_expiry")]
    pub jwt_access_expiry: i64,

    /// Refresh-token expiry; defaults to `604800`.
    // NOTE(review): presumably seconds (7 days) — unit is not shown here.
    #[serde(default = "default_jwt_refresh_expiry")]
    pub jwt_refresh_expiry: i64,

    /// Name of the environment variable holding the API key (required key).
    pub api_key_env: String,
}
155
/// Serde default for `AuthConfig::jwt_access_expiry` (evaluates to 900).
fn default_jwt_access_expiry() -> i64 {
    // Presumably seconds, i.e. 15 minutes — the unit is not stated in this file.
    15 * 60
}

/// Serde default for `AuthConfig::jwt_refresh_expiry` (evaluates to 604800).
fn default_jwt_refresh_expiry() -> i64 {
    // Presumably seconds, i.e. 7 days — the unit is not stated in this file.
    7 * 24 * 60 * 60
}
163
164impl Default for AuthConfig {
165 fn default() -> Self {
166 Self {
167 jwt_secret_env: "JWT_SECRET".to_string(),
168 jwt_access_expiry: default_jwt_access_expiry(),
169 jwt_refresh_expiry: default_jwt_refresh_expiry(),
170 api_key_env: "API_KEY".to_string(),
171 }
172 }
173}
174
/// Settings of the `[database]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Database URL/path; defaults to `"./data/ares.db"`.
    #[serde(default = "default_database_url")]
    pub url: String,

    /// Optional name of the environment variable holding the Turso URL.
    /// When set, validation requires that variable to exist.
    pub turso_url_env: Option<String>,

    /// Optional name of the environment variable holding the Turso token.
    /// When set, validation requires that variable to exist.
    pub turso_token_env: Option<String>,

    /// Optional Qdrant settings (`[database.qdrant]`).
    pub qdrant: Option<QdrantConfig>,
}
193
/// Serde default for `DatabaseConfig::url`.
fn default_database_url() -> String {
    String::from("./data/ares.db")
}
197
198impl Default for DatabaseConfig {
199 fn default() -> Self {
200 Self {
201 url: default_database_url(),
202 turso_url_env: None,
203 turso_token_env: None,
204 qdrant: None,
205 }
206 }
207}
208
/// Qdrant connection settings (nested under the database configuration).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConfig {
    /// Qdrant endpoint URL; defaults to `"http://localhost:6334"`.
    #[serde(default = "default_qdrant_url")]
    pub url: String,

    /// Optional name of the environment variable holding the Qdrant API key.
    /// When set, validation requires that variable to exist.
    pub api_key_env: Option<String>,
}
219
/// Serde default for `QdrantConfig::url`.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
223
224impl Default for QdrantConfig {
225 fn default() -> Self {
226 Self {
227 url: default_qdrant_url(),
228 api_key_env: None,
229 }
230 }
231}
232
/// Configuration of one LLM provider.
///
/// Internally tagged: the TOML table must contain a `type` key whose value is
/// the lowercased variant name (`"ollama"`, `"openai"`, `"llamacpp"`,
/// `"anthropic"`); the remaining keys are the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// An Ollama server.
    Ollama {
        /// Base URL; defaults to `"http://localhost:11434"`.
        #[serde(default = "default_ollama_url")]
        base_url: String,
        /// Default model name for this provider (required key).
        default_model: String,
    },
    /// The OpenAI API (or an endpoint reachable via `api_base`).
    OpenAI {
        /// Name of the environment variable holding the API key; validation
        /// requires it to be set.
        api_key_env: String,
        /// API base URL; defaults to `"https://api.openai.com/v1"`.
        #[serde(default = "default_openai_base")]
        api_base: String,
        /// Default model name for this provider (required key).
        default_model: String,
    },
    /// A model file run through llama.cpp.
    LlamaCpp {
        /// Path to the model file; validation requires the path to exist.
        model_path: String,
        /// Context size; defaults to `4096`.
        #[serde(default = "default_n_ctx")]
        n_ctx: u32,
        /// Thread count; defaults to `4`.
        #[serde(default = "default_n_threads")]
        n_threads: u32,
        /// Maximum tokens; defaults to `512`.
        #[serde(default = "default_max_tokens")]
        max_tokens: u32,
    },
    /// The Anthropic API.
    Anthropic {
        /// Name of the environment variable holding the API key; validation
        /// requires it to be set.
        api_key_env: String,
        /// Default model name for this provider (required key).
        default_model: String,
    },
}
279
/// Serde default for `ProviderConfig::Ollama::base_url`.
fn default_ollama_url() -> String {
    String::from("http://localhost:11434")
}

/// Serde default for `ProviderConfig::OpenAI::api_base`.
fn default_openai_base() -> String {
    String::from("https://api.openai.com/v1")
}

/// Serde default for `ProviderConfig::LlamaCpp::n_ctx`.
fn default_n_ctx() -> u32 {
    4096_u32
}

/// Serde default for `ProviderConfig::LlamaCpp::n_threads`.
fn default_n_threads() -> u32 {
    4_u32
}

/// Serde default for `ProviderConfig::LlamaCpp::max_tokens`.
fn default_max_tokens() -> u32 {
    512_u32
}
299
/// One named model definition (`[models.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Name of the provider entry this model uses; validation requires it to
    /// be a key of the providers map.
    pub provider: String,

    /// Model identifier (required key).
    pub model: String,

    /// Sampling temperature; defaults to `0.7`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Maximum tokens; defaults to `512`.
    #[serde(default = "default_model_max_tokens")]
    pub max_tokens: u32,

    /// Optional `top_p` value.
    pub top_p: Option<f32>,

    /// Optional frequency penalty.
    pub frequency_penalty: Option<f32>,

    /// Optional presence penalty.
    pub presence_penalty: Option<f32>,
}
328
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde default for `ModelConfig::max_tokens`.
fn default_model_max_tokens() -> u32 {
    512_u32
}
336
/// One named tool definition (`[tools.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// Whether the tool is enabled; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,

    /// Timeout in seconds; defaults to `30`.
    #[serde(default = "default_tool_timeout")]
    pub timeout_secs: u64,

    /// Any additional keys of the tool's table that are not named above,
    /// collected verbatim by `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}
358
/// Serde default for boolean fields that should be on unless disabled.
fn default_true() -> bool {
    !false
}

/// Serde default for `ToolConfig::timeout_secs`.
fn default_tool_timeout() -> u64 {
    30_u64
}
366
367impl Default for ToolConfig {
368 fn default() -> Self {
369 Self {
370 enabled: true,
371 description: None,
372 timeout_secs: default_tool_timeout(),
373 extra: HashMap::new(),
374 }
375 }
376}
377
/// One named agent definition (`[agents.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Name of the model entry this agent uses; validation requires it to be
    /// a key of the models map.
    pub model: String,

    /// Optional system prompt.
    #[serde(default)]
    pub system_prompt: Option<String>,

    /// Names of the tools this agent may use; validation requires each to be
    /// a key of the tools map. Defaults to an empty list.
    #[serde(default)]
    pub tools: Vec<String>,

    /// Maximum number of tool iterations; defaults to `10`.
    #[serde(default = "default_max_tool_iterations")]
    pub max_tool_iterations: usize,

    /// Parallel-tools flag; defaults to `false`.
    #[serde(default)]
    pub parallel_tools: bool,

    /// Any additional keys of the agent's table that are not named above,
    /// collected verbatim by `#[serde(flatten)]`.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}
406
/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    10_usize
}
410
/// One named workflow definition (`[workflows.<name>]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
    /// Name of the agent the workflow starts with; validation requires it to
    /// be a key of the agents map.
    pub entry_agent: String,

    /// Optional fallback agent name; when present, validation requires it to
    /// be a key of the agents map and to differ from `entry_agent`.
    pub fallback_agent: Option<String>,

    /// Maximum depth; defaults to `3`.
    #[serde(default = "default_max_depth")]
    pub max_depth: u8,

    /// Maximum iterations; defaults to `5`.
    #[serde(default = "default_max_iterations")]
    pub max_iterations: u8,

    /// Parallel-subagents flag; defaults to `false`.
    #[serde(default)]
    pub parallel_subagents: bool,
}
434
/// Serde default for `WorkflowConfig::max_depth`.
fn default_max_depth() -> u8 {
    3_u8
}

/// Serde default for `WorkflowConfig::max_iterations`.
fn default_max_iterations() -> u8 {
    5_u8
}
442
/// Settings of the `[rag]` table (retrieval-augmented generation).
///
/// Every key is optional; missing keys take the defaults noted per field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Vector store identifier; defaults to `"ares-vector"`.
    #[serde(default = "default_vector_store")]
    pub vector_store: String,

    /// Vector storage path; defaults to `"./data/vectors"`.
    #[serde(default = "default_vector_path")]
    pub vector_path: String,

    /// Embedding model name; defaults to `"bge-small-en-v1.5"`.
    #[serde(default = "default_embedding_model")]
    pub embedding_model: String,

    /// Sparse-embeddings flag; defaults to `false`.
    #[serde(default)]
    pub sparse_embeddings: bool,

    /// Sparse embedding model name; defaults to `"splade-pp-en-v1"`.
    #[serde(default = "default_sparse_model")]
    pub sparse_model: String,

    /// Chunking strategy name; defaults to `"word"`.
    #[serde(default = "default_chunking_strategy")]
    pub chunking_strategy: String,

    /// Chunk size; defaults to `200`.
    // NOTE(review): unit presumably follows `chunking_strategy` — confirm.
    #[serde(default = "default_chunk_size")]
    pub chunk_size: usize,

    /// Overlap between consecutive chunks; defaults to `50`.
    #[serde(default = "default_chunk_overlap")]
    pub chunk_overlap: usize,

    /// Minimum chunk size; defaults to `20`.
    #[serde(default = "default_min_chunk_size")]
    pub min_chunk_size: usize,

    /// Search strategy name; defaults to `"semantic"`.
    #[serde(default = "default_search_strategy")]
    pub search_strategy: String,

    /// Search result limit; defaults to `10`.
    #[serde(default = "default_search_limit")]
    pub search_limit: usize,

    /// Search threshold; defaults to `0.0`.
    #[serde(default)]
    pub search_threshold: f32,

    /// Optional per-signal weights for hybrid search; absent by default.
    #[serde(default)]
    pub hybrid_weights: Option<HybridWeightsConfig>,

    /// Reranking flag; defaults to `false`.
    #[serde(default)]
    pub rerank_enabled: bool,

    /// Reranker model name; defaults to `"bge-reranker-base"`.
    #[serde(default = "default_reranker_model")]
    pub reranker_model: String,

    /// Rerank weight; defaults to `0.6`.
    #[serde(default = "default_rerank_weight")]
    pub rerank_weight: f32,
}
520
/// Weights of the individual signals combined by hybrid search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridWeightsConfig {
    /// Weight of the semantic signal; defaults to `0.5`.
    #[serde(default = "default_semantic_weight")]
    pub semantic: f32,
    /// Weight of the BM25 signal; defaults to `0.3`.
    #[serde(default = "default_bm25_weight")]
    pub bm25: f32,
    /// Weight of the fuzzy signal; defaults to `0.2`.
    #[serde(default = "default_fuzzy_weight")]
    pub fuzzy: f32,
}
534
535impl Default for HybridWeightsConfig {
536 fn default() -> Self {
537 Self {
538 semantic: 0.5,
539 bm25: 0.3,
540 fuzzy: 0.2,
541 }
542 }
543}
544
/// Serde default for `HybridWeightsConfig::semantic`.
fn default_semantic_weight() -> f32 {
    0.5_f32
}

/// Serde default for `HybridWeightsConfig::bm25`.
fn default_bm25_weight() -> f32 {
    0.3_f32
}

/// Serde default for `HybridWeightsConfig::fuzzy`.
fn default_fuzzy_weight() -> f32 {
    0.2_f32
}

/// Serde default for `RagConfig::vector_store`.
fn default_vector_store() -> String {
    String::from("ares-vector")
}

/// Serde default for `RagConfig::vector_path`.
fn default_vector_path() -> String {
    String::from("./data/vectors")
}

/// Serde default for `RagConfig::embedding_model`.
fn default_embedding_model() -> String {
    String::from("bge-small-en-v1.5")
}

/// Serde default for `RagConfig::sparse_model`.
fn default_sparse_model() -> String {
    String::from("splade-pp-en-v1")
}

/// Serde default for `RagConfig::chunking_strategy`.
fn default_chunking_strategy() -> String {
    String::from("word")
}

/// Serde default for `RagConfig::chunk_size`.
fn default_chunk_size() -> usize {
    200_usize
}

/// Serde default for `RagConfig::chunk_overlap`.
fn default_chunk_overlap() -> usize {
    50_usize
}

/// Serde default for `RagConfig::min_chunk_size`.
fn default_min_chunk_size() -> usize {
    20_usize
}

/// Serde default for `RagConfig::search_strategy`.
fn default_search_strategy() -> String {
    String::from("semantic")
}

/// Serde default for `RagConfig::search_limit`.
fn default_search_limit() -> usize {
    10_usize
}

/// Serde default for `RagConfig::reranker_model`.
fn default_reranker_model() -> String {
    String::from("bge-reranker-base")
}

/// Serde default for `RagConfig::rerank_weight`.
fn default_rerank_weight() -> f32 {
    0.6_f32
}
604
605impl Default for RagConfig {
606 fn default() -> Self {
607 Self {
608 vector_store: default_vector_store(),
609 vector_path: default_vector_path(),
610 embedding_model: default_embedding_model(),
611 sparse_embeddings: false,
612 sparse_model: default_sparse_model(),
613 chunking_strategy: default_chunking_strategy(),
614 chunk_size: default_chunk_size(),
615 chunk_overlap: default_chunk_overlap(),
616 min_chunk_size: default_min_chunk_size(),
617 search_strategy: default_search_strategy(),
618 search_limit: default_search_limit(),
619 search_threshold: 0.0,
620 hybrid_weights: None,
621 rerank_enabled: false,
622 reranker_model: default_reranker_model(),
623 rerank_weight: default_rerank_weight(),
624 }
625 }
626}
627
628#[derive(Debug, Clone, Serialize, Deserialize)]
636pub struct DynamicConfigPaths {
637 #[serde(default = "default_agents_dir")]
639 pub agents_dir: std::path::PathBuf,
640
641 #[serde(default = "default_workflows_dir")]
643 pub workflows_dir: std::path::PathBuf,
644
645 #[serde(default = "default_models_dir")]
647 pub models_dir: std::path::PathBuf,
648
649 #[serde(default = "default_tools_dir")]
651 pub tools_dir: std::path::PathBuf,
652
653 #[serde(default = "default_mcps_dir")]
655 pub mcps_dir: std::path::PathBuf,
656
657 #[serde(default = "default_hot_reload")]
659 pub hot_reload: bool,
660
661 #[serde(default = "default_watch_interval")]
663 pub watch_interval_ms: u64,
664}
665
/// Serde default for `DynamicConfigPaths::agents_dir`.
fn default_agents_dir() -> PathBuf {
    "config/agents".into()
}

/// Serde default for `DynamicConfigPaths::workflows_dir`.
fn default_workflows_dir() -> PathBuf {
    "config/workflows".into()
}

/// Serde default for `DynamicConfigPaths::models_dir`.
fn default_models_dir() -> PathBuf {
    "config/models".into()
}

/// Serde default for `DynamicConfigPaths::tools_dir`.
fn default_tools_dir() -> PathBuf {
    "config/tools".into()
}

/// Serde default for `DynamicConfigPaths::mcps_dir`.
fn default_mcps_dir() -> PathBuf {
    "config/mcps".into()
}

/// Serde default for `DynamicConfigPaths::hot_reload`.
fn default_hot_reload() -> bool {
    !false
}

/// Serde default for `DynamicConfigPaths::watch_interval_ms`.
fn default_watch_interval() -> u64 {
    1000_u64
}
693
694impl Default for DynamicConfigPaths {
695 fn default() -> Self {
696 Self {
697 agents_dir: default_agents_dir(),
698 workflows_dir: default_workflows_dir(),
699 models_dir: default_models_dir(),
700 tools_dir: default_tools_dir(),
701 mcps_dir: default_mcps_dir(),
702 hot_reload: default_hot_reload(),
703 watch_interval_ms: default_watch_interval(),
704 }
705 }
706}
707
/// A non-fatal finding about a configuration that passed validation.
///
/// Its `Display` output is the `message` text.
#[derive(Debug, Clone)]
pub struct ConfigWarning {
    /// Category of the finding.
    pub kind: ConfigWarningKind,

    /// Human-readable description of the finding.
    pub message: String,
}
719
/// Category of a [`ConfigWarning`].
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigWarningKind {
    /// A provider is defined but no model references it.
    UnusedProvider,

    /// A model is defined but no agent references it.
    UnusedModel,

    /// A tool is defined but no agent references it.
    UnusedTool,

    /// An agent is defined but no workflow references it.
    UnusedAgent,
}
735
736impl std::fmt::Display for ConfigWarning {
737 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
738 write!(f, "{}", self.message)
739 }
740}
741
/// Errors produced while loading, validating or watching the configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file does not exist at the given path.
    #[error("Configuration file not found: {0}")]
    FileNotFound(PathBuf),

    /// An I/O error occurred (converted automatically from `std::io::Error`).
    #[error("Failed to read configuration file: {0}")]
    ReadError(#[from] std::io::Error),

    /// The file content is not valid TOML for the expected schema.
    #[error("Failed to parse TOML: {0}")]
    ParseError(#[from] toml::de::Error),

    /// A semantic check failed; the payload is the explanation.
    #[error("Validation error: {0}")]
    ValidationError(String),

    /// An environment variable named in the configuration could not be read;
    /// the payload is the variable name.
    #[error("Environment variable '{0}' referenced in config is not set")]
    MissingEnvVar(String),

    /// Payload: (missing provider name, model that references it).
    #[error("Provider '{0}' referenced by model '{1}' does not exist")]
    MissingProvider(String, String),

    /// Payload: (missing model name, agent that references it).
    #[error("Model '{0}' referenced by agent '{1}' does not exist")]
    MissingModel(String, String),

    /// Payload: (missing agent name, workflow that references it).
    #[error("Agent '{0}' referenced by workflow '{1}' does not exist")]
    MissingAgent(String, String),

    /// Payload: (missing tool name, agent that references it).
    #[error("Tool '{0}' referenced by agent '{1}' does not exist")]
    MissingTool(String, String),

    /// A workflow refers back to itself; the payload is the explanation.
    #[error("Circular reference detected: {0}")]
    CircularReference(String),

    /// The file-system watcher reported an error (converted automatically
    /// from `notify::Error`).
    #[error("Watch error: {0}")]
    WatchError(#[from] notify::Error),
}
789
790impl AresConfig {
791 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
798 let path = path.as_ref();
799
800 if !path.exists() {
801 return Err(ConfigError::FileNotFound(path.to_path_buf()));
802 }
803
804 let content = fs::read_to_string(path)?;
805 let config: AresConfig = toml::from_str(&content)?;
806
807 config.validate()?;
809
810 Ok(config)
811 }
812
813 pub fn load_unchecked<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
819 let path = path.as_ref();
820
821 if !path.exists() {
822 return Err(ConfigError::FileNotFound(path.to_path_buf()));
823 }
824
825 let content = fs::read_to_string(path)?;
826 let config: AresConfig = toml::from_str(&content)?;
827
828 Ok(config)
829 }
830
831 pub fn validate(&self) -> Result<(), ConfigError> {
833 self.validate_env_var(&self.auth.jwt_secret_env)?;
835 self.validate_env_var(&self.auth.api_key_env)?;
836
837 if let Some(ref env) = self.database.turso_url_env {
839 self.validate_env_var(env)?;
840 }
841 if let Some(ref env) = self.database.turso_token_env {
842 self.validate_env_var(env)?;
843 }
844 if let Some(ref qdrant) = self.database.qdrant {
845 if let Some(ref env) = qdrant.api_key_env {
846 self.validate_env_var(env)?;
847 }
848 }
849
850 for (name, provider) in &self.providers {
852 match provider {
853 ProviderConfig::OpenAI { api_key_env, .. } => {
854 self.validate_env_var(api_key_env)?;
855 }
856 ProviderConfig::Anthropic { api_key_env, .. } => {
857 self.validate_env_var(api_key_env)?;
858 }
859 ProviderConfig::LlamaCpp { model_path, .. } => {
860 if !Path::new(model_path).exists() {
862 return Err(ConfigError::ValidationError(format!(
863 "LlamaCpp model path does not exist: {} (provider: {})",
864 model_path, name
865 )));
866 }
867 }
868 ProviderConfig::Ollama { .. } => {
869 }
871 }
872 }
873
874 for (model_name, model_config) in &self.models {
876 if !self.providers.contains_key(&model_config.provider) {
877 return Err(ConfigError::MissingProvider(
878 model_config.provider.clone(),
879 model_name.clone(),
880 ));
881 }
882 }
883
884 for (agent_name, agent_config) in &self.agents {
886 if !self.models.contains_key(&agent_config.model) {
887 return Err(ConfigError::MissingModel(
888 agent_config.model.clone(),
889 agent_name.clone(),
890 ));
891 }
892
893 for tool_name in &agent_config.tools {
894 if !self.tools.contains_key(tool_name) {
895 return Err(ConfigError::MissingTool(
896 tool_name.clone(),
897 agent_name.clone(),
898 ));
899 }
900 }
901 }
902
903 for (workflow_name, workflow_config) in &self.workflows {
905 if !self.agents.contains_key(&workflow_config.entry_agent) {
906 return Err(ConfigError::MissingAgent(
907 workflow_config.entry_agent.clone(),
908 workflow_name.clone(),
909 ));
910 }
911
912 if let Some(ref fallback) = workflow_config.fallback_agent {
913 if !self.agents.contains_key(fallback) {
914 return Err(ConfigError::MissingAgent(
915 fallback.clone(),
916 workflow_name.clone(),
917 ));
918 }
919 }
920 }
921
922 self.detect_circular_references()?;
924
925 Ok(())
926 }
927
928 fn detect_circular_references(&self) -> Result<(), ConfigError> {
933 use std::collections::HashSet;
934
935 for (workflow_name, workflow_config) in &self.workflows {
936 let mut visited = HashSet::new();
937 let mut current = Some(workflow_config.entry_agent.as_str());
938
939 while let Some(agent_name) = current {
940 if visited.contains(agent_name) {
941 return Err(ConfigError::CircularReference(format!(
942 "Circular reference detected in workflow '{}': agent '{}' appears multiple times in the chain",
943 workflow_name, agent_name
944 )));
945 }
946 visited.insert(agent_name);
947
948 current = None;
951
952 if let Some(ref fallback) = workflow_config.fallback_agent {
954 if fallback == &workflow_config.entry_agent {
955 return Err(ConfigError::CircularReference(format!(
956 "Workflow '{}' has entry_agent '{}' that equals fallback_agent",
957 workflow_name, workflow_config.entry_agent
958 )));
959 }
960 }
961 }
962 }
963
964 Ok(())
965 }
966
967 pub fn validate_with_warnings(&self) -> Result<Vec<ConfigWarning>, ConfigError> {
971 self.validate()?;
973
974 let mut warnings = Vec::new();
976
977 warnings.extend(self.check_unused_providers());
979
980 warnings.extend(self.check_unused_models());
982
983 warnings.extend(self.check_unused_tools());
985
986 warnings.extend(self.check_unused_agents());
988
989 Ok(warnings)
990 }
991
992 fn check_unused_providers(&self) -> Vec<ConfigWarning> {
994 use std::collections::HashSet;
995
996 let referenced: HashSet<_> = self.models.values().map(|m| m.provider.as_str()).collect();
997
998 self.providers
999 .keys()
1000 .filter(|name| !referenced.contains(name.as_str()))
1001 .map(|name| ConfigWarning {
1002 kind: ConfigWarningKind::UnusedProvider,
1003 message: format!(
1004 "Provider '{}' is defined but not referenced by any model",
1005 name
1006 ),
1007 })
1008 .collect()
1009 }
1010
1011 fn check_unused_models(&self) -> Vec<ConfigWarning> {
1013 use std::collections::HashSet;
1014
1015 let referenced: HashSet<_> = self.agents.values().map(|a| a.model.as_str()).collect();
1016
1017 self.models
1018 .keys()
1019 .filter(|name| !referenced.contains(name.as_str()))
1020 .map(|name| ConfigWarning {
1021 kind: ConfigWarningKind::UnusedModel,
1022 message: format!(
1023 "Model '{}' is defined but not referenced by any agent",
1024 name
1025 ),
1026 })
1027 .collect()
1028 }
1029
1030 fn check_unused_tools(&self) -> Vec<ConfigWarning> {
1032 use std::collections::HashSet;
1033
1034 let referenced: HashSet<_> = self
1035 .agents
1036 .values()
1037 .flat_map(|a| a.tools.iter().map(|t| t.as_str()))
1038 .collect();
1039
1040 self.tools
1041 .keys()
1042 .filter(|name| !referenced.contains(name.as_str()))
1043 .map(|name| ConfigWarning {
1044 kind: ConfigWarningKind::UnusedTool,
1045 message: format!("Tool '{}' is defined but not referenced by any agent", name),
1046 })
1047 .collect()
1048 }
1049
1050 fn check_unused_agents(&self) -> Vec<ConfigWarning> {
1052 use std::collections::HashSet;
1053
1054 let referenced: HashSet<_> = self
1055 .workflows
1056 .values()
1057 .flat_map(|w| {
1058 let mut refs = vec![w.entry_agent.as_str()];
1059 if let Some(ref fallback) = w.fallback_agent {
1060 refs.push(fallback.as_str());
1061 }
1062 refs
1063 })
1064 .collect();
1065
1066 let system_agents: HashSet<&str> = ["orchestrator", "router"].into_iter().collect();
1068
1069 self.agents
1070 .keys()
1071 .filter(|name| {
1072 !referenced.contains(name.as_str()) && !system_agents.contains(name.as_str())
1073 })
1074 .map(|name| ConfigWarning {
1075 kind: ConfigWarningKind::UnusedAgent,
1076 message: format!(
1077 "Agent '{}' is defined but not referenced by any workflow",
1078 name
1079 ),
1080 })
1081 .collect()
1082 }
1083
1084 fn validate_env_var(&self, name: &str) -> Result<(), ConfigError> {
1085 std::env::var(name).map_err(|_| ConfigError::MissingEnvVar(name.to_string()))?;
1086 Ok(())
1087 }
1088
1089 pub fn resolve_env(&self, env_name: &str) -> Option<String> {
1091 std::env::var(env_name).ok()
1092 }
1093
1094 const JWT_SECRET_MIN_LENGTH: usize = 32;
1096
1097 pub fn jwt_secret(&self) -> Result<String, ConfigError> {
1104 let secret = self
1105 .resolve_env(&self.auth.jwt_secret_env)
1106 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.jwt_secret_env.clone()))?;
1107
1108 if secret.len() < Self::JWT_SECRET_MIN_LENGTH {
1109 return Err(ConfigError::ValidationError(format!(
1110 "JWT_SECRET must be at least {} characters for security (current: {} chars). \
1111 Use a cryptographically random string, e.g.: openssl rand -base64 32",
1112 Self::JWT_SECRET_MIN_LENGTH,
1113 secret.len()
1114 )));
1115 }
1116
1117 Ok(secret)
1118 }
1119
1120 pub fn api_key(&self) -> Result<String, ConfigError> {
1122 self.resolve_env(&self.auth.api_key_env)
1123 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.api_key_env.clone()))
1124 }
1125
1126 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
1128 self.providers.get(name)
1129 }
1130
1131 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
1133 self.models.get(name)
1134 }
1135
1136 pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
1138 self.agents.get(name)
1139 }
1140
1141 pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
1143 self.tools.get(name)
1144 }
1145
1146 pub fn get_workflow(&self, name: &str) -> Option<&WorkflowConfig> {
1148 self.workflows.get(name)
1149 }
1150
1151 pub fn enabled_tools(&self) -> Vec<&str> {
1153 self.tools
1154 .iter()
1155 .filter(|(_, config)| config.enabled)
1156 .map(|(name, _)| name.as_str())
1157 .collect()
1158 }
1159
1160 pub fn agent_tools(&self, agent_name: &str) -> Vec<&str> {
1162 self.get_agent(agent_name)
1163 .map(|agent| {
1164 agent
1165 .tools
1166 .iter()
1167 .filter(|t| self.get_tool(t).map(|tc| tc.enabled).unwrap_or(false))
1168 .map(|s| s.as_str())
1169 .collect()
1170 })
1171 .unwrap_or_default()
1172 }
1173}
1174
/// Owns the live configuration and (optionally) a file watcher that swaps in
/// a freshly loaded configuration when the file changes.
pub struct AresConfigManager {
    /// Current configuration; replaced atomically on reload and shared with
    /// every clone of the manager.
    config: Arc<ArcSwap<AresConfig>>,
    /// Path of the configuration file that reloads read from.
    config_path: PathBuf,
    /// Active file-system watcher, if watching has been started on this
    /// instance.
    watcher: RwLock<Option<RecommendedWatcher>>,
    /// Sender half of the change-notification channel created by
    /// `start_watching`; `None` until watching has been started.
    reload_tx: Option<mpsc::UnboundedSender<()>>,
}
1184
1185impl AresConfigManager {
1186 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
1192 let path = path.as_ref();
1194 let path = if path.is_absolute() {
1195 path.to_path_buf()
1196 } else {
1197 std::env::current_dir()
1198 .map_err(ConfigError::ReadError)?
1199 .join(path)
1200 };
1201
1202 let config = AresConfig::load(&path)?;
1203
1204 Ok(Self {
1205 config: Arc::new(ArcSwap::from_pointee(config)),
1206 config_path: path,
1207 watcher: RwLock::new(None),
1208 reload_tx: None,
1209 })
1210 }
1211
1212 pub fn config(&self) -> Arc<AresConfig> {
1214 self.config.load_full()
1215 }
1216
1217 pub fn reload(&self) -> Result<(), ConfigError> {
1219 info!("Reloading configuration from {:?}", self.config_path);
1220
1221 let new_config = AresConfig::load(&self.config_path)?;
1222 self.config.store(Arc::new(new_config));
1223
1224 info!("Configuration reloaded successfully");
1225 Ok(())
1226 }
1227
1228 pub fn start_watching(&mut self) -> Result<(), ConfigError> {
1230 let (tx, mut rx) = mpsc::unbounded_channel::<()>();
1231 self.reload_tx = Some(tx.clone());
1232
1233 let config_path = self.config_path.clone();
1234 let config_arc = Arc::clone(&self.config);
1235
1236 let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
1238 match res {
1239 Ok(event) => {
1240 if event.kind.is_modify() || event.kind.is_create() {
1241 let _ = tx.send(());
1243 }
1244 }
1245 Err(e) => {
1246 error!("Config watcher error: {:?}", e);
1247 }
1248 }
1249 })?;
1250
1251 if let Some(parent) = self.config_path.parent() {
1253 watcher.watch(parent, RecursiveMode::NonRecursive)?;
1254 }
1255
1256 *self.watcher.write() = Some(watcher);
1257
1258 let config_path_clone = config_path.clone();
1260 tokio::spawn(async move {
1261 let mut last_reload = std::time::Instant::now();
1262 let debounce_duration = Duration::from_millis(500);
1263
1264 while rx.recv().await.is_some() {
1265 if last_reload.elapsed() < debounce_duration {
1267 continue;
1268 }
1269
1270 tokio::time::sleep(Duration::from_millis(100)).await;
1272
1273 match AresConfig::load(&config_path_clone) {
1274 Ok(new_config) => {
1275 config_arc.store(Arc::new(new_config));
1276 info!("Configuration hot-reloaded successfully");
1277 last_reload = std::time::Instant::now();
1278 }
1279 Err(e) => {
1280 warn!(
1281 "Failed to hot-reload config: {}. Keeping previous config.",
1282 e
1283 );
1284 }
1285 }
1286 }
1287 });
1288
1289 info!("Configuration hot-reload watcher started");
1290 Ok(())
1291 }
1292
1293 pub fn stop_watching(&self) {
1295 *self.watcher.write() = None;
1296 info!("Configuration hot-reload watcher stopped");
1297 }
1298}
1299
1300impl Clone for AresConfigManager {
1301 fn clone(&self) -> Self {
1302 Self {
1303 config: Arc::clone(&self.config),
1304 config_path: self.config_path.clone(),
1305 watcher: RwLock::new(None), reload_tx: self.reload_tx.clone(),
1307 }
1308 }
1309}
1310
1311impl AresConfigManager {
1312 pub fn from_config(config: AresConfig) -> Self {
1315 Self {
1316 config: Arc::new(ArcSwap::from_pointee(config)),
1317 config_path: PathBuf::from("test-config.toml"),
1318 watcher: RwLock::new(None),
1319 reload_tx: None,
1320 }
1321 }
1322}
1323
1324#[cfg(test)]
1325mod tests {
1326 use super::*;
1327
1328 fn create_test_config() -> String {
1329 r#"
1330[server]
1331host = "127.0.0.1"
1332port = 3000
1333log_level = "debug"
1334
1335[auth]
1336jwt_secret_env = "TEST_JWT_SECRET"
1337jwt_access_expiry = 900
1338jwt_refresh_expiry = 604800
1339api_key_env = "TEST_API_KEY"
1340
1341[database]
1342url = "./data/test.db"
1343
1344[providers.ollama-local]
1345type = "ollama"
1346base_url = "http://localhost:11434"
1347default_model = "ministral-3:3b"
1348
1349[models.default]
1350provider = "ollama-local"
1351model = "ministral-3:3b"
1352temperature = 0.7
1353max_tokens = 512
1354
1355[tools.calculator]
1356enabled = true
1357description = "Basic calculator"
1358timeout_secs = 10
1359
1360[agents.router]
1361model = "default"
1362tools = []
1363max_tool_iterations = 5
1364
1365[workflows.default]
1366entry_agent = "router"
1367max_depth = 3
1368max_iterations = 5
1369"#
1370 .to_string()
1371 }
1372
1373 #[test]
1374 fn test_parse_config() {
1375 unsafe {
1378 std::env::set_var(
1379 "TEST_JWT_SECRET",
1380 "test-secret-at-least-32-characters-long-at-least-32-characters-long",
1381 );
1382 std::env::set_var("TEST_API_KEY", "test-api-key");
1383 }
1384
1385 let content = create_test_config();
1386 let config: AresConfig = toml::from_str(&content).expect("Failed to parse config");
1387
1388 assert_eq!(config.server.host, "127.0.0.1");
1389 assert_eq!(config.server.port, 3000);
1390 assert!(config.providers.contains_key("ollama-local"));
1391 assert!(config.models.contains_key("default"));
1392 assert!(config.agents.contains_key("router"));
1393 }
1394
1395 #[test]
1396 fn test_validation_missing_provider() {
1397 unsafe {
1399 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1400 std::env::set_var("TEST_API_KEY", "test-key");
1401 }
1402
1403 let content = r#"
1404[server]
1405[auth]
1406jwt_secret_env = "TEST_JWT_SECRET"
1407api_key_env = "TEST_API_KEY"
1408[database]
1409[models.test]
1410provider = "nonexistent"
1411model = "test"
1412"#;
1413
1414 let config: AresConfig = toml::from_str(content).unwrap();
1415 let result = config.validate();
1416
1417 assert!(matches!(result, Err(ConfigError::MissingProvider(_, _))));
1418 }
1419
1420 #[test]
1421 fn test_validation_missing_model() {
1422 unsafe {
1424 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1425 std::env::set_var("TEST_API_KEY", "test-key");
1426 }
1427
1428 let content = r#"
1429[server]
1430[auth]
1431jwt_secret_env = "TEST_JWT_SECRET"
1432api_key_env = "TEST_API_KEY"
1433[database]
1434[providers.test]
1435type = "ollama"
1436default_model = "ministral-3:3b"
1437[agents.test]
1438model = "nonexistent"
1439"#;
1440
1441 let config: AresConfig = toml::from_str(content).unwrap();
1442 let result = config.validate();
1443
1444 assert!(matches!(result, Err(ConfigError::MissingModel(_, _))));
1445 }
1446
1447 #[test]
1448 fn test_validation_missing_tool() {
1449 unsafe {
1451 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1452 std::env::set_var("TEST_API_KEY", "test-key");
1453 }
1454
1455 let content = r#"
1456[server]
1457[auth]
1458jwt_secret_env = "TEST_JWT_SECRET"
1459api_key_env = "TEST_API_KEY"
1460[database]
1461[providers.test]
1462type = "ollama"
1463default_model = "ministral-3:3b"
1464[models.default]
1465provider = "test"
1466model = "ministral-3:3b"
1467[agents.test]
1468model = "default"
1469tools = ["nonexistent_tool"]
1470"#;
1471
1472 let config: AresConfig = toml::from_str(content).unwrap();
1473 let result = config.validate();
1474
1475 assert!(matches!(result, Err(ConfigError::MissingTool(_, _))));
1476 }
1477
1478 #[test]
1479 fn test_validation_missing_workflow_agent() {
1480 unsafe {
1482 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1483 std::env::set_var("TEST_API_KEY", "test-key");
1484 }
1485
1486 let content = r#"
1487[server]
1488[auth]
1489jwt_secret_env = "TEST_JWT_SECRET"
1490api_key_env = "TEST_API_KEY"
1491[database]
1492[workflows.test]
1493entry_agent = "nonexistent_agent"
1494"#;
1495
1496 let config: AresConfig = toml::from_str(content).unwrap();
1497 let result = config.validate();
1498
1499 assert!(matches!(result, Err(ConfigError::MissingAgent(_, _))));
1500 }
1501
1502 #[test]
1503 fn test_get_provider() {
1504 let content = create_test_config();
1505 let config: AresConfig = toml::from_str(&content).unwrap();
1506
1507 assert!(config.get_provider("ollama-local").is_some());
1508 assert!(config.get_provider("nonexistent").is_none());
1509 }
1510
1511 #[test]
1512 fn test_get_model() {
1513 let content = create_test_config();
1514 let config: AresConfig = toml::from_str(&content).unwrap();
1515
1516 assert!(config.get_model("default").is_some());
1517 assert!(config.get_model("nonexistent").is_none());
1518 }
1519
1520 #[test]
1521 fn test_get_agent() {
1522 let content = create_test_config();
1523 let config: AresConfig = toml::from_str(&content).unwrap();
1524
1525 assert!(config.get_agent("router").is_some());
1526 assert!(config.get_agent("nonexistent").is_none());
1527 }
1528
1529 #[test]
1530 fn test_get_tool() {
1531 let content = create_test_config();
1532 let config: AresConfig = toml::from_str(&content).unwrap();
1533
1534 assert!(config.get_tool("calculator").is_some());
1535 assert!(config.get_tool("nonexistent").is_none());
1536 }
1537
1538 #[test]
1539 fn test_enabled_tools() {
1540 let content = r#"
1541[server]
1542[auth]
1543jwt_secret_env = "TEST_JWT_SECRET"
1544api_key_env = "TEST_API_KEY"
1545[database]
1546[tools.enabled_tool]
1547enabled = true
1548[tools.disabled_tool]
1549enabled = false
1550"#;
1551
1552 let config: AresConfig = toml::from_str(content).unwrap();
1553 let enabled = config.enabled_tools();
1554
1555 assert!(enabled.contains(&"enabled_tool"));
1556 assert!(!enabled.contains(&"disabled_tool"));
1557 }
1558
1559 #[test]
1560 fn test_defaults() {
1561 let content = r#"
1562[server]
1563[auth]
1564jwt_secret_env = "TEST_JWT_SECRET"
1565api_key_env = "TEST_API_KEY"
1566[database]
1567"#;
1568
1569 let config: AresConfig = toml::from_str(content).unwrap();
1570
1571 assert_eq!(config.server.host, "127.0.0.1");
1573 assert_eq!(config.server.port, 3000);
1574 assert_eq!(config.server.log_level, "info");
1575
1576 assert_eq!(config.auth.jwt_access_expiry, 900);
1578 assert_eq!(config.auth.jwt_refresh_expiry, 604800);
1579
1580 assert_eq!(config.database.url, "./data/ares.db");
1582
1583 assert_eq!(config.rag.embedding_model, "bge-small-en-v1.5");
1585 assert_eq!(config.rag.chunk_size, 200);
1586 assert_eq!(config.rag.chunk_overlap, 50);
1587 assert_eq!(config.rag.vector_store, "ares-vector");
1588 assert_eq!(config.rag.search_strategy, "semantic");
1589 }
1590
1591 #[test]
1592 fn test_config_manager_from_config() {
1593 let content = create_test_config();
1594 let config: AresConfig = toml::from_str(&content).unwrap();
1595
1596 let manager = AresConfigManager::from_config(config.clone());
1597 let loaded = manager.config();
1598
1599 assert_eq!(loaded.server.host, config.server.host);
1600 assert_eq!(loaded.server.port, config.server.port);
1601 }
1602
1603 #[test]
1604 fn test_circular_reference_detection() {
1605 unsafe {
1607 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1608 std::env::set_var("TEST_API_KEY", "test-key");
1609 }
1610
1611 let content = r#"
1612[server]
1613[auth]
1614jwt_secret_env = "TEST_JWT_SECRET"
1615api_key_env = "TEST_API_KEY"
1616[database]
1617[providers.test]
1618type = "ollama"
1619default_model = "ministral-3:3b"
1620[models.default]
1621provider = "test"
1622model = "ministral-3:3b"
1623[agents.agent_a]
1624model = "default"
1625[workflows.circular]
1626entry_agent = "agent_a"
1627fallback_agent = "agent_a"
1628"#;
1629
1630 let config: AresConfig = toml::from_str(content).unwrap();
1631 let result = config.validate();
1632
1633 assert!(matches!(result, Err(ConfigError::CircularReference(_))));
1634 }
1635
1636 #[test]
1637 fn test_unused_provider_warning() {
1638 unsafe {
1640 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1641 std::env::set_var("TEST_API_KEY", "test-key");
1642 }
1643
1644 let content = r#"
1645[server]
1646[auth]
1647jwt_secret_env = "TEST_JWT_SECRET"
1648api_key_env = "TEST_API_KEY"
1649[database]
1650[providers.used]
1651type = "ollama"
1652default_model = "ministral-3:3b"
1653[providers.unused]
1654type = "ollama"
1655default_model = "ministral-3:3b"
1656[models.default]
1657provider = "used"
1658model = "ministral-3:3b"
1659[agents.router]
1660model = "default"
1661"#;
1662
1663 let config: AresConfig = toml::from_str(content).unwrap();
1664 let warnings = config.validate_with_warnings().unwrap();
1665
1666 assert!(warnings
1667 .iter()
1668 .any(|w| w.kind == ConfigWarningKind::UnusedProvider && w.message.contains("unused")));
1669 }
1670
1671 #[test]
1672 fn test_unused_model_warning() {
1673 unsafe {
1675 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1676 std::env::set_var("TEST_API_KEY", "test-key");
1677 }
1678
1679 let content = r#"
1680[server]
1681[auth]
1682jwt_secret_env = "TEST_JWT_SECRET"
1683api_key_env = "TEST_API_KEY"
1684[database]
1685[providers.test]
1686type = "ollama"
1687default_model = "ministral-3:3b"
1688[models.used]
1689provider = "test"
1690model = "ministral-3:3b"
1691[models.unused]
1692provider = "test"
1693model = "other"
1694[agents.router]
1695model = "used"
1696"#;
1697
1698 let config: AresConfig = toml::from_str(content).unwrap();
1699 let warnings = config.validate_with_warnings().unwrap();
1700
1701 assert!(warnings
1702 .iter()
1703 .any(|w| w.kind == ConfigWarningKind::UnusedModel && w.message.contains("unused")));
1704 }
1705
1706 #[test]
1707 fn test_unused_tool_warning() {
1708 unsafe {
1710 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1711 std::env::set_var("TEST_API_KEY", "test-key");
1712 }
1713
1714 let content = r#"
1715[server]
1716[auth]
1717jwt_secret_env = "TEST_JWT_SECRET"
1718api_key_env = "TEST_API_KEY"
1719[database]
1720[providers.test]
1721type = "ollama"
1722default_model = "ministral-3:3b"
1723[models.default]
1724provider = "test"
1725model = "ministral-3:3b"
1726[tools.used_tool]
1727enabled = true
1728[tools.unused_tool]
1729enabled = true
1730[agents.router]
1731model = "default"
1732tools = ["used_tool"]
1733"#;
1734
1735 let config: AresConfig = toml::from_str(content).unwrap();
1736 let warnings = config.validate_with_warnings().unwrap();
1737
1738 assert!(warnings
1739 .iter()
1740 .any(|w| w.kind == ConfigWarningKind::UnusedTool && w.message.contains("unused_tool")));
1741 }
1742
1743 #[test]
1744 fn test_unused_agent_warning() {
1745 unsafe {
1747 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1748 std::env::set_var("TEST_API_KEY", "test-key");
1749 }
1750
1751 let content = r#"
1752[server]
1753[auth]
1754jwt_secret_env = "TEST_JWT_SECRET"
1755api_key_env = "TEST_API_KEY"
1756[database]
1757[providers.test]
1758type = "ollama"
1759default_model = "ministral-3:3b"
1760[models.default]
1761provider = "test"
1762model = "ministral-3:3b"
1763[agents.router]
1764model = "default"
1765[agents.orphaned]
1766model = "default"
1767[workflows.test_flow]
1768entry_agent = "router"
1769"#;
1770
1771 let config: AresConfig = toml::from_str(content).unwrap();
1772 let warnings = config.validate_with_warnings().unwrap();
1773
1774 assert!(warnings
1775 .iter()
1776 .any(|w| w.kind == ConfigWarningKind::UnusedAgent && w.message.contains("orphaned")));
1777 }
1778
1779 #[test]
1780 fn test_no_warnings_for_fully_connected_config() {
1781 unsafe {
1783 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1784 std::env::set_var("TEST_API_KEY", "test-key");
1785 }
1786
1787 let content = r#"
1788[server]
1789[auth]
1790jwt_secret_env = "TEST_JWT_SECRET"
1791api_key_env = "TEST_API_KEY"
1792[database]
1793[providers.test]
1794type = "ollama"
1795default_model = "ministral-3:3b"
1796[models.default]
1797provider = "test"
1798model = "ministral-3:3b"
1799[tools.calc]
1800enabled = true
1801[agents.router]
1802model = "default"
1803tools = ["calc"]
1804[workflows.main]
1805entry_agent = "router"
1806"#;
1807
1808 let config: AresConfig = toml::from_str(content).unwrap();
1809 let warnings = config.validate_with_warnings().unwrap();
1810
1811 assert!(
1812 warnings.is_empty(),
1813 "Expected no warnings but got: {:?}",
1814 warnings
1815 );
1816 }
1817}