1use arc_swap::ArcSwap;
12use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::mpsc;
21use tracing::{error, info, warn};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct AresConfig {
26 pub server: ServerConfig,
28
29 pub auth: AuthConfig,
31
32 pub database: DatabaseConfig,
34
35 #[serde(default)]
37 pub providers: HashMap<String, ProviderConfig>,
38
39 #[serde(default)]
42 pub models: HashMap<String, ModelConfig>,
43
44 #[serde(default)]
47 pub tools: HashMap<String, ToolConfig>,
48
49 #[serde(default)]
52 pub agents: HashMap<String, AgentConfig>,
53
54 #[serde(default)]
57 pub workflows: HashMap<String, WorkflowConfig>,
58
59 #[serde(default)]
61 pub rag: RagConfig,
62
63 #[serde(default)]
65 pub config: DynamicConfigPaths,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ServerConfig {
73 #[serde(default = "default_host")]
75 pub host: String,
76
77 #[serde(default = "default_port")]
79 pub port: u16,
80
81 #[serde(default = "default_log_level")]
83 pub log_level: String,
84
85 #[serde(default = "default_cors_origins")]
88 pub cors_origins: Vec<String>,
89
90 #[serde(default = "default_rate_limit")]
92 pub rate_limit_per_second: u32,
93
94 #[serde(default = "default_rate_limit_burst")]
96 pub rate_limit_burst: u32,
97}
98
/// Serde default for `ServerConfig::host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `ServerConfig::port`.
fn default_port() -> u16 {
    3_000
}

/// Serde default for `ServerConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `ServerConfig::cors_origins`.
fn default_cors_origins() -> Vec<String> {
    Vec::from(["https://admin.dirmacs.com", "https://eruka.dirmacs.com"].map(String::from))
}

/// Serde default for `ServerConfig::rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100
}

/// Serde default for `ServerConfig::rate_limit_burst`.
fn default_rate_limit_burst() -> u32 {
    10
}
125
126impl Default for ServerConfig {
127 fn default() -> Self {
128 Self {
129 host: default_host(),
130 port: default_port(),
131 log_level: default_log_level(),
132 cors_origins: default_cors_origins(),
133 rate_limit_per_second: default_rate_limit(),
134 rate_limit_burst: default_rate_limit_burst(),
135 }
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct AuthConfig {
144 pub jwt_secret_env: String,
146
147 #[serde(default = "default_jwt_access_expiry")]
149 pub jwt_access_expiry: i64,
150
151 #[serde(default = "default_jwt_refresh_expiry")]
153 pub jwt_refresh_expiry: i64,
154
155 pub api_key_env: String,
157}
158
/// Serde default for `AuthConfig::jwt_access_expiry`: 900, i.e. 15 × 60
/// (fifteen minutes if the unit is seconds — TODO confirm with the token issuer).
fn default_jwt_access_expiry() -> i64 {
    15 * 60
}

/// Serde default for `AuthConfig::jwt_refresh_expiry`: 604 800, i.e.
/// 7 × 24 × 60 × 60 (one week if the unit is seconds — TODO confirm).
fn default_jwt_refresh_expiry() -> i64 {
    7 * 24 * 60 * 60
}
166
167impl Default for AuthConfig {
168 fn default() -> Self {
169 Self {
170 jwt_secret_env: "JWT_SECRET".to_string(),
171 jwt_access_expiry: default_jwt_access_expiry(),
172 jwt_refresh_expiry: default_jwt_refresh_expiry(),
173 api_key_env: "API_KEY".to_string(),
174 }
175 }
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct DatabaseConfig {
183 #[serde(default = "default_database_url")]
185 pub url: String,
186
187 pub qdrant: Option<QdrantConfig>,
189}
190
/// Serde default for `DatabaseConfig::url`: the `ares` database on a local
/// PostgreSQL server with `postgres:postgres` credentials.
///
/// NOTE(review): the `test_defaults` unit test in this file asserts
/// `"./data/ares.db"` for this default; one of the two is stale.
fn default_database_url() -> String {
    "postgres://postgres:postgres@localhost:5432/ares".into()
}
194
195impl Default for DatabaseConfig {
196 fn default() -> Self {
197 Self {
198 url: default_database_url(),
199 qdrant: None,
200 }
201 }
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct QdrantConfig {
207 #[serde(default = "default_qdrant_url")]
209 pub url: String,
210
211 pub api_key_env: Option<String>,
213}
214
/// Serde default for `QdrantConfig::url`.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
218
219impl Default for QdrantConfig {
220 fn default() -> Self {
221 Self {
222 url: default_qdrant_url(),
223 api_key_env: None,
224 }
225 }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
232#[serde(tag = "type", rename_all = "lowercase")]
233pub enum ProviderConfig {
234 Ollama {
236 #[serde(default = "default_ollama_url")]
238 base_url: String,
239 default_model: String,
241 },
242 OpenAI {
244 api_key_env: String,
246 #[serde(default = "default_openai_base")]
248 api_base: String,
249 default_model: String,
251 },
252 LlamaCpp {
254 model_path: String,
256 #[serde(default = "default_n_ctx")]
258 n_ctx: u32,
259 #[serde(default = "default_n_threads")]
261 n_threads: u32,
262 #[serde(default = "default_max_tokens")]
264 max_tokens: u32,
265 },
266 Anthropic {
268 api_key_env: String,
270 default_model: String,
272 },
273}
274
/// Serde default for `ProviderConfig::Ollama::base_url`.
fn default_ollama_url() -> String {
    "http://localhost:11434".into()
}

/// Serde default for `ProviderConfig::OpenAI::api_base`.
fn default_openai_base() -> String {
    "https://api.openai.com/v1".into()
}

/// Serde default for `ProviderConfig::LlamaCpp::n_ctx`.
fn default_n_ctx() -> u32 {
    4_096
}

/// Serde default for `ProviderConfig::LlamaCpp::n_threads`.
fn default_n_threads() -> u32 {
    4
}

/// Serde default for `ProviderConfig::LlamaCpp::max_tokens`.
fn default_max_tokens() -> u32 {
    512
}
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct ModelConfig {
300 pub provider: String,
302
303 pub model: String,
305
306 #[serde(default = "default_temperature")]
308 pub temperature: f32,
309
310 #[serde(default = "default_model_max_tokens")]
312 pub max_tokens: u32,
313
314 pub top_p: Option<f32>,
316
317 pub frequency_penalty: Option<f32>,
319
320 pub presence_penalty: Option<f32>,
322}
323
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde default for `ModelConfig::max_tokens`.
fn default_model_max_tokens() -> u32 {
    512_u32
}
331
332#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct ToolConfig {
337 #[serde(default = "default_true")]
339 pub enabled: bool,
340
341 #[serde(default)]
343 pub description: Option<String>,
344
345 #[serde(default = "default_tool_timeout")]
347 pub timeout_secs: u64,
348
349 #[serde(flatten)]
351 pub extra: HashMap<String, toml::Value>,
352}
353
/// Serde default helper for boolean fields that default to `true`
/// (used by `ToolConfig::enabled`).
fn default_true() -> bool {
    true
}

/// Serde default for `ToolConfig::timeout_secs`, in seconds.
fn default_tool_timeout() -> u64 {
    30_u64
}
361
362impl Default for ToolConfig {
363 fn default() -> Self {
364 Self {
365 enabled: true,
366 description: None,
367 timeout_secs: default_tool_timeout(),
368 extra: HashMap::new(),
369 }
370 }
371}
372
373#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct AgentConfig {
378 pub model: String,
380
381 #[serde(default)]
383 pub system_prompt: Option<String>,
384
385 #[serde(default)]
387 pub tools: Vec<String>,
388
389 #[serde(default = "default_max_tool_iterations")]
391 pub max_tool_iterations: usize,
392
393 #[serde(default)]
395 pub parallel_tools: bool,
396
397 #[serde(flatten)]
399 pub extra: HashMap<String, toml::Value>,
400}
401
/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    10_usize
}
405
406#[derive(Debug, Clone, Serialize, Deserialize)]
410pub struct WorkflowConfig {
411 pub entry_agent: String,
413
414 pub fallback_agent: Option<String>,
416
417 #[serde(default = "default_max_depth")]
419 pub max_depth: u8,
420
421 #[serde(default = "default_max_iterations")]
423 pub max_iterations: u8,
424
425 #[serde(default)]
427 pub parallel_subagents: bool,
428}
429
/// Serde default for `WorkflowConfig::max_depth`.
fn default_max_depth() -> u8 {
    3_u8
}

/// Serde default for `WorkflowConfig::max_iterations`.
fn default_max_iterations() -> u8 {
    5_u8
}
437
438#[derive(Debug, Clone, Serialize, Deserialize)]
442pub struct RagConfig {
443 #[serde(default = "default_vector_store")]
446 pub vector_store: String,
447
448 #[serde(default = "default_vector_path")]
450 pub vector_path: String,
451
452 #[serde(default = "default_embedding_model")]
457 pub embedding_model: String,
458
459 #[serde(default)]
461 pub sparse_embeddings: bool,
462
463 #[serde(default = "default_sparse_model")]
465 pub sparse_model: String,
466
467 #[serde(default = "default_chunking_strategy")]
470 pub chunking_strategy: String,
471
472 #[serde(default = "default_chunk_size")]
474 pub chunk_size: usize,
475
476 #[serde(default = "default_chunk_overlap")]
478 pub chunk_overlap: usize,
479
480 #[serde(default = "default_min_chunk_size")]
482 pub min_chunk_size: usize,
483
484 #[serde(default = "default_search_strategy")]
487 pub search_strategy: String,
488
489 #[serde(default = "default_search_limit")]
491 pub search_limit: usize,
492
493 #[serde(default)]
495 pub search_threshold: f32,
496
497 #[serde(default)]
499 pub hybrid_weights: Option<HybridWeightsConfig>,
500
501 #[serde(default)]
504 pub rerank_enabled: bool,
505
506 #[serde(default = "default_reranker_model")]
509 pub reranker_model: String,
510
511 #[serde(default = "default_rerank_weight")]
513 pub rerank_weight: f32,
514}
515
516#[derive(Debug, Clone, Serialize, Deserialize)]
518pub struct HybridWeightsConfig {
519 #[serde(default = "default_semantic_weight")]
521 pub semantic: f32,
522 #[serde(default = "default_bm25_weight")]
524 pub bm25: f32,
525 #[serde(default = "default_fuzzy_weight")]
527 pub fuzzy: f32,
528}
529
530impl Default for HybridWeightsConfig {
531 fn default() -> Self {
532 Self {
533 semantic: 0.5,
534 bm25: 0.3,
535 fuzzy: 0.2,
536 }
537 }
538}
539
/// Serde default for `HybridWeightsConfig::semantic`.
fn default_semantic_weight() -> f32 {
    0.5_f32
}

/// Serde default for `HybridWeightsConfig::bm25`.
fn default_bm25_weight() -> f32 {
    0.3_f32
}

/// Serde default for `HybridWeightsConfig::fuzzy`.
fn default_fuzzy_weight() -> f32 {
    0.2_f32
}
551
/// Serde default for `RagConfig::vector_store`.
fn default_vector_store() -> String {
    "ares-vector".into()
}

/// Serde default for `RagConfig::vector_path`.
fn default_vector_path() -> String {
    "./data/vectors".into()
}

/// Serde default for `RagConfig::embedding_model`.
fn default_embedding_model() -> String {
    "bge-small-en-v1.5".into()
}

/// Serde default for `RagConfig::sparse_model`.
fn default_sparse_model() -> String {
    "splade-pp-en-v1".into()
}

/// Serde default for `RagConfig::chunking_strategy`.
fn default_chunking_strategy() -> String {
    "word".into()
}

/// Serde default for `RagConfig::chunk_size`.
fn default_chunk_size() -> usize {
    200_usize
}

/// Serde default for `RagConfig::chunk_overlap`.
fn default_chunk_overlap() -> usize {
    50_usize
}

/// Serde default for `RagConfig::min_chunk_size`.
fn default_min_chunk_size() -> usize {
    20_usize
}

/// Serde default for `RagConfig::search_strategy`.
fn default_search_strategy() -> String {
    "semantic".into()
}

/// Serde default for `RagConfig::search_limit`.
fn default_search_limit() -> usize {
    10_usize
}

/// Serde default for `RagConfig::reranker_model`.
fn default_reranker_model() -> String {
    "bge-reranker-base".into()
}

/// Serde default for `RagConfig::rerank_weight`.
fn default_rerank_weight() -> f32 {
    0.6_f32
}
599
600impl Default for RagConfig {
601 fn default() -> Self {
602 Self {
603 vector_store: default_vector_store(),
604 vector_path: default_vector_path(),
605 embedding_model: default_embedding_model(),
606 sparse_embeddings: false,
607 sparse_model: default_sparse_model(),
608 chunking_strategy: default_chunking_strategy(),
609 chunk_size: default_chunk_size(),
610 chunk_overlap: default_chunk_overlap(),
611 min_chunk_size: default_min_chunk_size(),
612 search_strategy: default_search_strategy(),
613 search_limit: default_search_limit(),
614 search_threshold: 0.0,
615 hybrid_weights: None,
616 rerank_enabled: false,
617 reranker_model: default_reranker_model(),
618 rerank_weight: default_rerank_weight(),
619 }
620 }
621}
622
623#[derive(Debug, Clone, Serialize, Deserialize)]
631pub struct DynamicConfigPaths {
632 #[serde(default = "default_agents_dir")]
634 pub agents_dir: std::path::PathBuf,
635
636 #[serde(default = "default_workflows_dir")]
638 pub workflows_dir: std::path::PathBuf,
639
640 #[serde(default = "default_models_dir")]
642 pub models_dir: std::path::PathBuf,
643
644 #[serde(default = "default_tools_dir")]
646 pub tools_dir: std::path::PathBuf,
647
648 #[serde(default = "default_mcps_dir")]
650 pub mcps_dir: std::path::PathBuf,
651
652 #[serde(default = "default_hot_reload")]
654 pub hot_reload: bool,
655
656 #[serde(default = "default_watch_interval")]
658 pub watch_interval_ms: u64,
659}
660
/// Serde default for `DynamicConfigPaths::agents_dir`.
fn default_agents_dir() -> PathBuf {
    "config/agents".into()
}

/// Serde default for `DynamicConfigPaths::workflows_dir`.
fn default_workflows_dir() -> PathBuf {
    "config/workflows".into()
}

/// Serde default for `DynamicConfigPaths::models_dir`.
fn default_models_dir() -> PathBuf {
    "config/models".into()
}

/// Serde default for `DynamicConfigPaths::tools_dir`.
fn default_tools_dir() -> PathBuf {
    "config/tools".into()
}

/// Serde default for `DynamicConfigPaths::mcps_dir`.
fn default_mcps_dir() -> PathBuf {
    "config/mcps".into()
}

/// Serde default for `DynamicConfigPaths::hot_reload`.
fn default_hot_reload() -> bool {
    true
}

/// Serde default for `DynamicConfigPaths::watch_interval_ms`.
fn default_watch_interval() -> u64 {
    1_000
}
688
689impl Default for DynamicConfigPaths {
690 fn default() -> Self {
691 Self {
692 agents_dir: default_agents_dir(),
693 workflows_dir: default_workflows_dir(),
694 models_dir: default_models_dir(),
695 tools_dir: default_tools_dir(),
696 mcps_dir: default_mcps_dir(),
697 hot_reload: default_hot_reload(),
698 watch_interval_ms: default_watch_interval(),
699 }
700 }
701}
702
/// A non-fatal finding reported by `AresConfig::validate_with_warnings`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigWarning {
    /// Category of the finding.
    pub kind: ConfigWarningKind,

    /// Human-readable description; this is exactly what `Display` prints.
    pub message: String,
}

/// Category of a [`ConfigWarning`].
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so kinds can be used as
/// `HashMap` / `HashSet` keys (e.g. to group or count warnings).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConfigWarningKind {
    /// A provider that no model references.
    UnusedProvider,

    /// A model that no agent references.
    UnusedModel,

    /// A tool that no agent lists.
    UnusedTool,

    /// An agent that no workflow references.
    UnusedAgent,
}

impl std::fmt::Display for ConfigWarning {
    /// Writes the warning's message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
736
737#[derive(Debug, thiserror::Error)]
739pub enum ConfigError {
740 #[error("Configuration file not found: {0}")]
742 FileNotFound(PathBuf),
743
744 #[error("Failed to read configuration file: {0}")]
746 ReadError(#[from] std::io::Error),
747
748 #[error("Failed to parse TOML: {0}")]
750 ParseError(#[from] toml::de::Error),
751
752 #[error("Validation error: {0}")]
754 ValidationError(String),
755
756 #[error("Environment variable '{0}' referenced in config is not set")]
758 MissingEnvVar(String),
759
760 #[error("Provider '{0}' referenced by model '{1}' does not exist")]
762 MissingProvider(String, String),
763
764 #[error("Model '{0}' referenced by agent '{1}' does not exist")]
766 MissingModel(String, String),
767
768 #[error("Agent '{0}' referenced by workflow '{1}' does not exist")]
770 MissingAgent(String, String),
771
772 #[error("Tool '{0}' referenced by agent '{1}' does not exist")]
774 MissingTool(String, String),
775
776 #[error("Circular reference detected: {0}")]
778 CircularReference(String),
779
780 #[error("Watch error: {0}")]
782 WatchError(#[from] notify::Error),
783}
784
785impl AresConfig {
786 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
793 let path = path.as_ref();
794
795 if !path.exists() {
796 return Err(ConfigError::FileNotFound(path.to_path_buf()));
797 }
798
799 let content = fs::read_to_string(path)?;
800 let config: AresConfig = toml::from_str(&content)?;
801
802 config.validate()?;
804
805 Ok(config)
806 }
807
808 pub fn load_unchecked<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
814 let path = path.as_ref();
815
816 if !path.exists() {
817 return Err(ConfigError::FileNotFound(path.to_path_buf()));
818 }
819
820 let content = fs::read_to_string(path)?;
821 let config: AresConfig = toml::from_str(&content)?;
822
823 Ok(config)
824 }
825
826 pub fn validate(&self) -> Result<(), ConfigError> {
828 self.validate_env_var(&self.auth.jwt_secret_env)?;
830 self.validate_env_var(&self.auth.api_key_env)?;
831
832 if let Some(ref qdrant) = self.database.qdrant {
834 if let Some(ref env) = qdrant.api_key_env {
835 self.validate_env_var(env)?;
836 }
837 }
838
839 for (name, provider) in &self.providers {
841 match provider {
842 ProviderConfig::OpenAI { api_key_env, .. } => {
843 self.validate_env_var(api_key_env)?;
844 }
845 ProviderConfig::Anthropic { api_key_env, .. } => {
846 self.validate_env_var(api_key_env)?;
847 }
848 ProviderConfig::LlamaCpp { model_path, .. } => {
849 if !Path::new(model_path).exists() {
851 return Err(ConfigError::ValidationError(format!(
852 "LlamaCpp model path does not exist: {} (provider: {})",
853 model_path, name
854 )));
855 }
856 }
857 ProviderConfig::Ollama { .. } => {
858 }
860 }
861 }
862
863 for (model_name, model_config) in &self.models {
865 if !self.providers.contains_key(&model_config.provider) {
866 return Err(ConfigError::MissingProvider(
867 model_config.provider.clone(),
868 model_name.clone(),
869 ));
870 }
871 }
872
873 for (agent_name, agent_config) in &self.agents {
875 if !self.models.contains_key(&agent_config.model) {
876 return Err(ConfigError::MissingModel(
877 agent_config.model.clone(),
878 agent_name.clone(),
879 ));
880 }
881
882 for tool_name in &agent_config.tools {
883 if !self.tools.contains_key(tool_name) {
884 return Err(ConfigError::MissingTool(
885 tool_name.clone(),
886 agent_name.clone(),
887 ));
888 }
889 }
890 }
891
892 for (workflow_name, workflow_config) in &self.workflows {
894 if !self.agents.contains_key(&workflow_config.entry_agent) {
895 return Err(ConfigError::MissingAgent(
896 workflow_config.entry_agent.clone(),
897 workflow_name.clone(),
898 ));
899 }
900
901 if let Some(ref fallback) = workflow_config.fallback_agent {
902 if !self.agents.contains_key(fallback) {
903 return Err(ConfigError::MissingAgent(
904 fallback.clone(),
905 workflow_name.clone(),
906 ));
907 }
908 }
909 }
910
911 self.detect_circular_references()?;
913
914 Ok(())
915 }
916
917 fn detect_circular_references(&self) -> Result<(), ConfigError> {
922 use std::collections::HashSet;
923
924 for (workflow_name, workflow_config) in &self.workflows {
925 let mut visited = HashSet::new();
926 let mut current = Some(workflow_config.entry_agent.as_str());
927
928 while let Some(agent_name) = current {
929 if visited.contains(agent_name) {
930 return Err(ConfigError::CircularReference(format!(
931 "Circular reference detected in workflow '{}': agent '{}' appears multiple times in the chain",
932 workflow_name, agent_name
933 )));
934 }
935 visited.insert(agent_name);
936
937 current = None;
940
941 if let Some(ref fallback) = workflow_config.fallback_agent {
943 if fallback == &workflow_config.entry_agent {
944 return Err(ConfigError::CircularReference(format!(
945 "Workflow '{}' has entry_agent '{}' that equals fallback_agent",
946 workflow_name, workflow_config.entry_agent
947 )));
948 }
949 }
950 }
951 }
952
953 Ok(())
954 }
955
956 pub fn validate_with_warnings(&self) -> Result<Vec<ConfigWarning>, ConfigError> {
960 self.validate()?;
962
963 let mut warnings = Vec::new();
965
966 warnings.extend(self.check_unused_providers());
968
969 warnings.extend(self.check_unused_models());
971
972 warnings.extend(self.check_unused_tools());
974
975 warnings.extend(self.check_unused_agents());
977
978 Ok(warnings)
979 }
980
981 fn check_unused_providers(&self) -> Vec<ConfigWarning> {
983 use std::collections::HashSet;
984
985 let referenced: HashSet<_> = self.models.values().map(|m| m.provider.as_str()).collect();
986
987 self.providers
988 .keys()
989 .filter(|name| !referenced.contains(name.as_str()))
990 .map(|name| ConfigWarning {
991 kind: ConfigWarningKind::UnusedProvider,
992 message: format!(
993 "Provider '{}' is defined but not referenced by any model",
994 name
995 ),
996 })
997 .collect()
998 }
999
1000 fn check_unused_models(&self) -> Vec<ConfigWarning> {
1002 use std::collections::HashSet;
1003
1004 let referenced: HashSet<_> = self.agents.values().map(|a| a.model.as_str()).collect();
1005
1006 self.models
1007 .keys()
1008 .filter(|name| !referenced.contains(name.as_str()))
1009 .map(|name| ConfigWarning {
1010 kind: ConfigWarningKind::UnusedModel,
1011 message: format!(
1012 "Model '{}' is defined but not referenced by any agent",
1013 name
1014 ),
1015 })
1016 .collect()
1017 }
1018
1019 fn check_unused_tools(&self) -> Vec<ConfigWarning> {
1021 use std::collections::HashSet;
1022
1023 let referenced: HashSet<_> = self
1024 .agents
1025 .values()
1026 .flat_map(|a| a.tools.iter().map(|t| t.as_str()))
1027 .collect();
1028
1029 self.tools
1030 .keys()
1031 .filter(|name| !referenced.contains(name.as_str()))
1032 .map(|name| ConfigWarning {
1033 kind: ConfigWarningKind::UnusedTool,
1034 message: format!("Tool '{}' is defined but not referenced by any agent", name),
1035 })
1036 .collect()
1037 }
1038
1039 fn check_unused_agents(&self) -> Vec<ConfigWarning> {
1041 use std::collections::HashSet;
1042
1043 let referenced: HashSet<_> = self
1044 .workflows
1045 .values()
1046 .flat_map(|w| {
1047 let mut refs = vec![w.entry_agent.as_str()];
1048 if let Some(ref fallback) = w.fallback_agent {
1049 refs.push(fallback.as_str());
1050 }
1051 refs
1052 })
1053 .collect();
1054
1055 let system_agents: HashSet<&str> = ["orchestrator", "router"].into_iter().collect();
1057
1058 self.agents
1059 .keys()
1060 .filter(|name| {
1061 !referenced.contains(name.as_str()) && !system_agents.contains(name.as_str())
1062 })
1063 .map(|name| ConfigWarning {
1064 kind: ConfigWarningKind::UnusedAgent,
1065 message: format!(
1066 "Agent '{}' is defined but not referenced by any workflow",
1067 name
1068 ),
1069 })
1070 .collect()
1071 }
1072
1073 fn validate_env_var(&self, name: &str) -> Result<(), ConfigError> {
1074 std::env::var(name).map_err(|_| ConfigError::MissingEnvVar(name.to_string()))?;
1075 Ok(())
1076 }
1077
1078 pub fn resolve_env(&self, env_name: &str) -> Option<String> {
1080 std::env::var(env_name).ok()
1081 }
1082
1083 const JWT_SECRET_MIN_LENGTH: usize = 32;
1085
1086 pub fn jwt_secret(&self) -> Result<String, ConfigError> {
1093 let secret = self
1094 .resolve_env(&self.auth.jwt_secret_env)
1095 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.jwt_secret_env.clone()))?;
1096
1097 if secret.len() < Self::JWT_SECRET_MIN_LENGTH {
1098 return Err(ConfigError::ValidationError(format!(
1099 "JWT_SECRET must be at least {} characters for security (current: {} chars). \
1100 Use a cryptographically random string, e.g.: openssl rand -base64 32",
1101 Self::JWT_SECRET_MIN_LENGTH,
1102 secret.len()
1103 )));
1104 }
1105
1106 Ok(secret)
1107 }
1108
1109 pub fn api_key(&self) -> Result<String, ConfigError> {
1111 self.resolve_env(&self.auth.api_key_env)
1112 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.api_key_env.clone()))
1113 }
1114
1115 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
1117 self.providers.get(name)
1118 }
1119
1120 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
1122 self.models.get(name)
1123 }
1124
1125 pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
1127 self.agents.get(name)
1128 }
1129
1130 pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
1132 self.tools.get(name)
1133 }
1134
1135 pub fn get_workflow(&self, name: &str) -> Option<&WorkflowConfig> {
1137 self.workflows.get(name)
1138 }
1139
1140 pub fn enabled_tools(&self) -> Vec<&str> {
1142 self.tools
1143 .iter()
1144 .filter(|(_, config)| config.enabled)
1145 .map(|(name, _)| name.as_str())
1146 .collect()
1147 }
1148
1149 pub fn agent_tools(&self, agent_name: &str) -> Vec<&str> {
1151 self.get_agent(agent_name)
1152 .map(|agent| {
1153 agent
1154 .tools
1155 .iter()
1156 .filter(|t| self.get_tool(t).map(|tc| tc.enabled).unwrap_or(false))
1157 .map(|s| s.as_str())
1158 .collect()
1159 })
1160 .unwrap_or_default()
1161 }
1162}
1163
1164pub struct AresConfigManager {
1168 config: Arc<ArcSwap<AresConfig>>,
1169 config_path: PathBuf,
1170 watcher: RwLock<Option<RecommendedWatcher>>,
1171 reload_tx: Option<mpsc::UnboundedSender<()>>,
1172}
1173
1174impl AresConfigManager {
1175 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
1181 let path = path.as_ref();
1183 let path = if path.is_absolute() {
1184 path.to_path_buf()
1185 } else {
1186 std::env::current_dir()
1187 .map_err(ConfigError::ReadError)?
1188 .join(path)
1189 };
1190
1191 let config = AresConfig::load(&path)?;
1192
1193 Ok(Self {
1194 config: Arc::new(ArcSwap::from_pointee(config)),
1195 config_path: path,
1196 watcher: RwLock::new(None),
1197 reload_tx: None,
1198 })
1199 }
1200
1201 pub fn config(&self) -> Arc<AresConfig> {
1203 self.config.load_full()
1204 }
1205
1206 pub fn reload(&self) -> Result<(), ConfigError> {
1208 info!("Reloading configuration from {:?}", self.config_path);
1209
1210 let new_config = AresConfig::load(&self.config_path)?;
1211 self.config.store(Arc::new(new_config));
1212
1213 info!("Configuration reloaded successfully");
1214 Ok(())
1215 }
1216
1217 pub fn start_watching(&mut self) -> Result<(), ConfigError> {
1219 let (tx, mut rx) = mpsc::unbounded_channel::<()>();
1220 self.reload_tx = Some(tx.clone());
1221
1222 let config_path = self.config_path.clone();
1223 let config_arc = Arc::clone(&self.config);
1224
1225 let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
1227 match res {
1228 Ok(event) => {
1229 if event.kind.is_modify() || event.kind.is_create() {
1230 let _ = tx.send(());
1232 }
1233 }
1234 Err(e) => {
1235 error!("Config watcher error: {:?}", e);
1236 }
1237 }
1238 })?;
1239
1240 if let Some(parent) = self.config_path.parent() {
1242 watcher.watch(parent, RecursiveMode::NonRecursive)?;
1243 }
1244
1245 *self.watcher.write() = Some(watcher);
1246
1247 let config_path_clone = config_path.clone();
1249 tokio::spawn(async move {
1250 let mut last_reload = std::time::Instant::now();
1251 let debounce_duration = Duration::from_millis(500);
1252
1253 while rx.recv().await.is_some() {
1254 if last_reload.elapsed() < debounce_duration {
1256 continue;
1257 }
1258
1259 tokio::time::sleep(Duration::from_millis(100)).await;
1261
1262 match AresConfig::load(&config_path_clone) {
1263 Ok(new_config) => {
1264 config_arc.store(Arc::new(new_config));
1265 info!("Configuration hot-reloaded successfully");
1266 last_reload = std::time::Instant::now();
1267 }
1268 Err(e) => {
1269 warn!(
1270 "Failed to hot-reload config: {}. Keeping previous config.",
1271 e
1272 );
1273 }
1274 }
1275 }
1276 });
1277
1278 info!("Configuration hot-reload watcher started");
1279 Ok(())
1280 }
1281
1282 pub fn stop_watching(&self) {
1284 *self.watcher.write() = None;
1285 info!("Configuration hot-reload watcher stopped");
1286 }
1287}
1288
1289impl Clone for AresConfigManager {
1290 fn clone(&self) -> Self {
1291 Self {
1292 config: Arc::clone(&self.config),
1293 config_path: self.config_path.clone(),
1294 watcher: RwLock::new(None), reload_tx: self.reload_tx.clone(),
1296 }
1297 }
1298}
1299
1300impl AresConfigManager {
1301 pub fn from_config(config: AresConfig) -> Self {
1304 Self {
1305 config: Arc::new(ArcSwap::from_pointee(config)),
1306 config_path: PathBuf::from("test-config.toml"),
1307 watcher: RwLock::new(None),
1308 reload_tx: None,
1309 }
1310 }
1311}
1312
1313#[cfg(test)]
1314mod tests {
1315 use super::*;
1316
1317 fn create_test_config() -> String {
1318 r#"
1319[server]
1320host = "127.0.0.1"
1321port = 3000
1322log_level = "debug"
1323
1324[auth]
1325jwt_secret_env = "TEST_JWT_SECRET"
1326jwt_access_expiry = 900
1327jwt_refresh_expiry = 604800
1328api_key_env = "TEST_API_KEY"
1329
1330[database]
1331url = "./data/test.db"
1332
1333[providers.ollama-local]
1334type = "ollama"
1335base_url = "http://localhost:11434"
1336default_model = "ministral-3:3b"
1337
1338[models.default]
1339provider = "ollama-local"
1340model = "ministral-3:3b"
1341temperature = 0.7
1342max_tokens = 512
1343
1344[tools.calculator]
1345enabled = true
1346description = "Basic calculator"
1347timeout_secs = 10
1348
1349[agents.router]
1350model = "default"
1351tools = []
1352max_tool_iterations = 5
1353
1354[workflows.default]
1355entry_agent = "router"
1356max_depth = 3
1357max_iterations = 5
1358"#
1359 .to_string()
1360 }
1361
1362 #[test]
1363 fn test_parse_config() {
1364 unsafe {
1367 std::env::set_var(
1368 "TEST_JWT_SECRET",
1369 "test-secret-at-least-32-characters-long-at-least-32-characters-long",
1370 );
1371 std::env::set_var("TEST_API_KEY", "test-api-key");
1372 }
1373
1374 let content = create_test_config();
1375 let config: AresConfig = toml::from_str(&content).expect("Failed to parse config");
1376
1377 assert_eq!(config.server.host, "127.0.0.1");
1378 assert_eq!(config.server.port, 3000);
1379 assert!(config.providers.contains_key("ollama-local"));
1380 assert!(config.models.contains_key("default"));
1381 assert!(config.agents.contains_key("router"));
1382 }
1383
1384 #[test]
1385 fn test_validation_missing_provider() {
1386 unsafe {
1388 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1389 std::env::set_var("TEST_API_KEY", "test-key");
1390 }
1391
1392 let content = r#"
1393[server]
1394[auth]
1395jwt_secret_env = "TEST_JWT_SECRET"
1396api_key_env = "TEST_API_KEY"
1397[database]
1398[models.test]
1399provider = "nonexistent"
1400model = "test"
1401"#;
1402
1403 let config: AresConfig = toml::from_str(content).unwrap();
1404 let result = config.validate();
1405
1406 assert!(matches!(result, Err(ConfigError::MissingProvider(_, _))));
1407 }
1408
1409 #[test]
1410 fn test_validation_missing_model() {
1411 unsafe {
1413 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1414 std::env::set_var("TEST_API_KEY", "test-key");
1415 }
1416
1417 let content = r#"
1418[server]
1419[auth]
1420jwt_secret_env = "TEST_JWT_SECRET"
1421api_key_env = "TEST_API_KEY"
1422[database]
1423[providers.test]
1424type = "ollama"
1425default_model = "ministral-3:3b"
1426[agents.test]
1427model = "nonexistent"
1428"#;
1429
1430 let config: AresConfig = toml::from_str(content).unwrap();
1431 let result = config.validate();
1432
1433 assert!(matches!(result, Err(ConfigError::MissingModel(_, _))));
1434 }
1435
1436 #[test]
1437 fn test_validation_missing_tool() {
1438 unsafe {
1440 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1441 std::env::set_var("TEST_API_KEY", "test-key");
1442 }
1443
1444 let content = r#"
1445[server]
1446[auth]
1447jwt_secret_env = "TEST_JWT_SECRET"
1448api_key_env = "TEST_API_KEY"
1449[database]
1450[providers.test]
1451type = "ollama"
1452default_model = "ministral-3:3b"
1453[models.default]
1454provider = "test"
1455model = "ministral-3:3b"
1456[agents.test]
1457model = "default"
1458tools = ["nonexistent_tool"]
1459"#;
1460
1461 let config: AresConfig = toml::from_str(content).unwrap();
1462 let result = config.validate();
1463
1464 assert!(matches!(result, Err(ConfigError::MissingTool(_, _))));
1465 }
1466
1467 #[test]
1468 fn test_validation_missing_workflow_agent() {
1469 unsafe {
1471 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1472 std::env::set_var("TEST_API_KEY", "test-key");
1473 }
1474
1475 let content = r#"
1476[server]
1477[auth]
1478jwt_secret_env = "TEST_JWT_SECRET"
1479api_key_env = "TEST_API_KEY"
1480[database]
1481[workflows.test]
1482entry_agent = "nonexistent_agent"
1483"#;
1484
1485 let config: AresConfig = toml::from_str(content).unwrap();
1486 let result = config.validate();
1487
1488 assert!(matches!(result, Err(ConfigError::MissingAgent(_, _))));
1489 }
1490
1491 #[test]
1492 fn test_get_provider() {
1493 let content = create_test_config();
1494 let config: AresConfig = toml::from_str(&content).unwrap();
1495
1496 assert!(config.get_provider("ollama-local").is_some());
1497 assert!(config.get_provider("nonexistent").is_none());
1498 }
1499
1500 #[test]
1501 fn test_get_model() {
1502 let content = create_test_config();
1503 let config: AresConfig = toml::from_str(&content).unwrap();
1504
1505 assert!(config.get_model("default").is_some());
1506 assert!(config.get_model("nonexistent").is_none());
1507 }
1508
1509 #[test]
1510 fn test_get_agent() {
1511 let content = create_test_config();
1512 let config: AresConfig = toml::from_str(&content).unwrap();
1513
1514 assert!(config.get_agent("router").is_some());
1515 assert!(config.get_agent("nonexistent").is_none());
1516 }
1517
1518 #[test]
1519 fn test_get_tool() {
1520 let content = create_test_config();
1521 let config: AresConfig = toml::from_str(&content).unwrap();
1522
1523 assert!(config.get_tool("calculator").is_some());
1524 assert!(config.get_tool("nonexistent").is_none());
1525 }
1526
#[test]
fn test_enabled_tools() {
    // One tool switched on and one switched off: only the first may be reported as enabled.
    const TOML: &str = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[tools.enabled_tool]
enabled = true
[tools.disabled_tool]
enabled = false
"#;

    let config = toml::from_str::<AresConfig>(TOML).unwrap();
    let enabled = config.enabled_tools();

    assert!(enabled.contains(&"enabled_tool"));
    assert!(!enabled.contains(&"disabled_tool"));
}
1547
#[test]
fn test_defaults() {
    // Only the auth env-var names are supplied; [server] and [database] are empty and
    // [rag] is absent, so every value checked below must come from a serde default.
    let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
"#;

    let config: AresConfig = toml::from_str(content).unwrap();

    // Server section.
    let server = &config.server;
    assert_eq!(server.host, "127.0.0.1");
    assert_eq!(server.port, 3000);
    assert_eq!(server.log_level, "info");

    // JWT expiry values.
    let auth = &config.auth;
    assert_eq!(auth.jwt_access_expiry, 900);
    assert_eq!(auth.jwt_refresh_expiry, 604800);

    // Database location.
    let database = &config.database;
    assert_eq!(database.url, "./data/ares.db");

    // RAG pipeline.
    let rag = &config.rag;
    assert_eq!(rag.embedding_model, "bge-small-en-v1.5");
    assert_eq!(rag.chunk_size, 200);
    assert_eq!(rag.chunk_overlap, 50);
    assert_eq!(rag.vector_store, "ares-vector");
    assert_eq!(rag.search_strategy, "semantic");
}
1579
#[test]
fn test_config_manager_from_config() {
    // A manager built from an in-memory config must hand back that same config.
    let original: AresConfig = toml::from_str(&create_test_config()).unwrap();

    let manager = AresConfigManager::from_config(original.clone());
    let snapshot = manager.config();

    assert_eq!(snapshot.server.host, original.server.host);
    assert_eq!(snapshot.server.port, original.server.port);
}
1591
#[test]
fn test_circular_reference_detection() {
    // SAFETY: `set_var` is unsound only while another thread reads or writes the
    // environment through something other than `std::env` (e.g. C `getenv`); calls
    // made through `std::env` are synchronised with each other. We assume nothing in
    // the test binary touches the environment via foreign code while tests run.
    // Sibling tests write these same values, so parallel ordering does not matter.
    unsafe {
        std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
        std::env::set_var("TEST_API_KEY", "test-key");
    }

    // Workflow `circular` names `agent_a` as both its entry and its fallback agent;
    // `validate()` must reject that as a circular reference.
    let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[agents.agent_a]
model = "default"
[workflows.circular]
entry_agent = "agent_a"
fallback_agent = "agent_a"
"#;

    let config: AresConfig = toml::from_str(content).unwrap();
    let result = config.validate();

    // Report what validate() actually returned so a failure is diagnosable. The
    // `map` drops the Ok payload so only the error type needs to be `Debug`.
    assert!(
        matches!(result, Err(ConfigError::CircularReference(_))),
        "Expected ConfigError::CircularReference but got: {:?}",
        result.as_ref().map(|_| ())
    );
}
1624
#[test]
fn test_unused_provider_warning() {
    // SAFETY: `set_var` is unsound only while another thread reads or writes the
    // environment through something other than `std::env` (e.g. C `getenv`); calls
    // made through `std::env` are synchronised with each other. We assume nothing in
    // the test binary touches the environment via foreign code while tests run.
    // Sibling tests write these same values, so parallel ordering does not matter.
    unsafe {
        std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
        std::env::set_var("TEST_API_KEY", "test-key");
    }

    // `spare_provider` is declared but no model references it. It has a distinctive
    // name on purpose: asserting on the substring "unused" could be satisfied by the
    // warning's own wording instead of the provider's name.
    let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.used]
type = "ollama"
default_model = "ministral-3:3b"
[providers.spare_provider]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "used"
model = "ministral-3:3b"
[agents.router]
model = "default"
"#;

    let config: AresConfig = toml::from_str(content).unwrap();
    let warnings = config.validate_with_warnings().unwrap();

    assert!(
        warnings.iter().any(|w| w.kind == ConfigWarningKind::UnusedProvider
            && w.message.contains("spare_provider")),
        "Expected an UnusedProvider warning naming 'spare_provider' but got: {:?}",
        warnings
    );
}
1659
#[test]
fn test_unused_model_warning() {
    // SAFETY: `set_var` is unsound only while another thread reads or writes the
    // environment through something other than `std::env` (e.g. C `getenv`); calls
    // made through `std::env` are synchronised with each other. We assume nothing in
    // the test binary touches the environment via foreign code while tests run.
    // Sibling tests write these same values, so parallel ordering does not matter.
    unsafe {
        std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
        std::env::set_var("TEST_API_KEY", "test-key");
    }

    // `spare_model` is declared but no agent references it. It has a distinctive
    // name on purpose: asserting on the substring "unused" could be satisfied by the
    // warning's own wording instead of the model's name.
    let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.used]
provider = "test"
model = "ministral-3:3b"
[models.spare_model]
provider = "test"
model = "other"
[agents.router]
model = "used"
"#;

    let config: AresConfig = toml::from_str(content).unwrap();
    let warnings = config.validate_with_warnings().unwrap();

    assert!(
        warnings.iter().any(|w| w.kind == ConfigWarningKind::UnusedModel
            && w.message.contains("spare_model")),
        "Expected an UnusedModel warning naming 'spare_model' but got: {:?}",
        warnings
    );
}
1694
#[test]
fn test_unused_tool_warning() {
    // SAFETY: `set_var` is unsound only while another thread reads or writes the
    // environment through something other than `std::env` (e.g. C `getenv`); calls
    // made through `std::env` are synchronised with each other. We assume nothing in
    // the test binary touches the environment via foreign code while tests run.
    // Sibling tests write these same values, so parallel ordering does not matter.
    unsafe {
        std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
        std::env::set_var("TEST_API_KEY", "test-key");
    }

    // `unused_tool` is enabled but no agent lists it in `tools`.
    let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[tools.used_tool]
enabled = true
[tools.unused_tool]
enabled = true
[agents.router]
model = "default"
tools = ["used_tool"]
"#;

    let config: AresConfig = toml::from_str(content).unwrap();
    let warnings = config.validate_with_warnings().unwrap();

    // Print the full warning list on failure, matching the fully-connected test.
    assert!(
        warnings.iter().any(|w| w.kind == ConfigWarningKind::UnusedTool
            && w.message.contains("unused_tool")),
        "Expected an UnusedTool warning naming 'unused_tool' but got: {:?}",
        warnings
    );
}
1731
#[test]
fn test_unused_agent_warning() {
    // SAFETY: `set_var` is unsound only while another thread reads or writes the
    // environment through something other than `std::env` (e.g. C `getenv`); calls
    // made through `std::env` are synchronised with each other. We assume nothing in
    // the test binary touches the environment via foreign code while tests run.
    // Sibling tests write these same values, so parallel ordering does not matter.
    unsafe {
        std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
        std::env::set_var("TEST_API_KEY", "test-key");
    }

    // `router` is the entry agent of `test_flow`; `orphaned` is referenced nowhere.
    let content = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[agents.router]
model = "default"
[agents.orphaned]
model = "default"
[workflows.test_flow]
entry_agent = "router"
"#;

    let config: AresConfig = toml::from_str(content).unwrap();
    let warnings = config.validate_with_warnings().unwrap();

    // Print the full warning list on failure, matching the fully-connected test.
    assert!(
        warnings.iter().any(|w| w.kind == ConfigWarningKind::UnusedAgent
            && w.message.contains("orphaned")),
        "Expected an UnusedAgent warning naming 'orphaned' but got: {:?}",
        warnings
    );
}
1767
#[test]
fn test_no_warnings_for_fully_connected_config() {
    // SAFETY: `set_var` is unsound only while another thread reads or writes the
    // environment through something other than `std::env` (e.g. C `getenv`); calls
    // made through `std::env` are synchronised with each other. We assume nothing in
    // the test binary touches the environment via foreign code while tests run.
    // Sibling tests write these same values, so parallel ordering does not matter.
    unsafe {
        std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
        std::env::set_var("TEST_API_KEY", "test-key");
    }

    // Everything declared here is reachable from workflow `main`:
    // main -> agent `router` -> model `default` -> provider `test`, and router uses tool `calc`.
    const FULLY_CONNECTED: &str = r#"
[server]
[auth]
jwt_secret_env = "TEST_JWT_SECRET"
api_key_env = "TEST_API_KEY"
[database]
[providers.test]
type = "ollama"
default_model = "ministral-3:3b"
[models.default]
provider = "test"
model = "ministral-3:3b"
[tools.calc]
enabled = true
[agents.router]
model = "default"
tools = ["calc"]
[workflows.main]
entry_agent = "router"
"#;

    let config = toml::from_str::<AresConfig>(FULLY_CONNECTED).unwrap();
    let warnings = config.validate_with_warnings().unwrap();
    let clean = warnings.is_empty();

    assert!(clean, "Expected no warnings but got: {:?}", warnings);
}
1806}