1use arc_swap::ArcSwap;
12use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::mpsc;
21use tracing::{error, info, warn};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct AresConfig {
26 pub server: ServerConfig,
28
29 pub auth: AuthConfig,
31
32 pub database: DatabaseConfig,
34
35 #[serde(default)]
37 pub providers: HashMap<String, ProviderConfig>,
38
39 #[serde(default)]
42 pub models: HashMap<String, ModelConfig>,
43
44 #[serde(default)]
47 pub tools: HashMap<String, ToolConfig>,
48
49 #[serde(default)]
52 pub agents: HashMap<String, AgentConfig>,
53
54 #[serde(default)]
57 pub workflows: HashMap<String, WorkflowConfig>,
58
59 #[serde(default)]
61 pub rag: RagConfig,
62
63 #[serde(default)]
65 pub config: DynamicConfigPaths,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ServerConfig {
73 #[serde(default = "default_host")]
75 pub host: String,
76
77 #[serde(default = "default_port")]
79 pub port: u16,
80
81 #[serde(default = "default_log_level")]
83 pub log_level: String,
84}
85
/// Serde default for `ServerConfig::host`.
fn default_host() -> String {
    String::from("127.0.0.1")
}
89
/// Serde default for `ServerConfig::port`.
fn default_port() -> u16 {
    const PORT: u16 = 3000;
    PORT
}
93
/// Serde default for `ServerConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}
97
98impl Default for ServerConfig {
99 fn default() -> Self {
100 Self {
101 host: default_host(),
102 port: default_port(),
103 log_level: default_log_level(),
104 }
105 }
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct AuthConfig {
113 pub jwt_secret_env: String,
115
116 #[serde(default = "default_jwt_access_expiry")]
118 pub jwt_access_expiry: i64,
119
120 #[serde(default = "default_jwt_refresh_expiry")]
122 pub jwt_refresh_expiry: i64,
123
124 pub api_key_env: String,
126}
127
/// Serde default for `AuthConfig::jwt_access_expiry`.
fn default_jwt_access_expiry() -> i64 {
    const ACCESS_EXPIRY: i64 = 900;
    ACCESS_EXPIRY
}
131
/// Serde default for `AuthConfig::jwt_refresh_expiry`.
fn default_jwt_refresh_expiry() -> i64 {
    const REFRESH_EXPIRY: i64 = 604_800;
    REFRESH_EXPIRY
}
135
136impl Default for AuthConfig {
137 fn default() -> Self {
138 Self {
139 jwt_secret_env: "JWT_SECRET".to_string(),
140 jwt_access_expiry: default_jwt_access_expiry(),
141 jwt_refresh_expiry: default_jwt_refresh_expiry(),
142 api_key_env: "API_KEY".to_string(),
143 }
144 }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct DatabaseConfig {
152 #[serde(default = "default_database_url")]
154 pub url: String,
155
156 pub turso_url_env: Option<String>,
158
159 pub turso_token_env: Option<String>,
161
162 pub qdrant: Option<QdrantConfig>,
164}
165
/// Serde default for `DatabaseConfig::url`.
fn default_database_url() -> String {
    String::from("./data/ares.db")
}
169
170impl Default for DatabaseConfig {
171 fn default() -> Self {
172 Self {
173 url: default_database_url(),
174 turso_url_env: None,
175 turso_token_env: None,
176 qdrant: None,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct QdrantConfig {
184 #[serde(default = "default_qdrant_url")]
186 pub url: String,
187
188 pub api_key_env: Option<String>,
190}
191
/// Serde default for `QdrantConfig::url`.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
195
196impl Default for QdrantConfig {
197 fn default() -> Self {
198 Self {
199 url: default_qdrant_url(),
200 api_key_env: None,
201 }
202 }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
209#[serde(tag = "type", rename_all = "lowercase")]
210pub enum ProviderConfig {
211 Ollama {
213 #[serde(default = "default_ollama_url")]
215 base_url: String,
216 default_model: String,
218 },
219 OpenAI {
221 api_key_env: String,
223 #[serde(default = "default_openai_base")]
225 api_base: String,
226 default_model: String,
228 },
229 LlamaCpp {
231 model_path: String,
233 #[serde(default = "default_n_ctx")]
235 n_ctx: u32,
236 #[serde(default = "default_n_threads")]
238 n_threads: u32,
239 #[serde(default = "default_max_tokens")]
241 max_tokens: u32,
242 },
243}
244
/// Serde default for the Ollama provider's `base_url`.
fn default_ollama_url() -> String {
    String::from("http://localhost:11434")
}
248
/// Serde default for the OpenAI provider's `api_base`.
fn default_openai_base() -> String {
    String::from("https://api.openai.com/v1")
}
252
/// Serde default for the llama.cpp provider's `n_ctx`.
fn default_n_ctx() -> u32 {
    const N_CTX: u32 = 4096;
    N_CTX
}
256
/// Serde default for the llama.cpp provider's `n_threads`.
fn default_n_threads() -> u32 {
    const N_THREADS: u32 = 4;
    N_THREADS
}
260
/// Serde default for the llama.cpp provider's `max_tokens`.
fn default_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 512;
    MAX_TOKENS
}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
269pub struct ModelConfig {
270 pub provider: String,
272
273 pub model: String,
275
276 #[serde(default = "default_temperature")]
278 pub temperature: f32,
279
280 #[serde(default = "default_model_max_tokens")]
282 pub max_tokens: u32,
283
284 pub top_p: Option<f32>,
286
287 pub frequency_penalty: Option<f32>,
289
290 pub presence_penalty: Option<f32>,
292}
293
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 0.7;
    TEMPERATURE
}
297
/// Serde default for `ModelConfig::max_tokens`.
fn default_model_max_tokens() -> u32 {
    const MAX_TOKENS: u32 = 512;
    MAX_TOKENS
}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
306pub struct ToolConfig {
307 #[serde(default = "default_true")]
309 pub enabled: bool,
310
311 #[serde(default)]
313 pub description: Option<String>,
314
315 #[serde(default = "default_tool_timeout")]
317 pub timeout_secs: u64,
318
319 #[serde(flatten)]
321 pub extra: HashMap<String, toml::Value>,
322}
323
/// Serde default helper that yields `true` (used for `ToolConfig::enabled`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
327
/// Serde default for `ToolConfig::timeout_secs`.
fn default_tool_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS
}
331
332impl Default for ToolConfig {
333 fn default() -> Self {
334 Self {
335 enabled: true,
336 description: None,
337 timeout_secs: default_tool_timeout(),
338 extra: HashMap::new(),
339 }
340 }
341}
342
343#[derive(Debug, Clone, Serialize, Deserialize)]
347pub struct AgentConfig {
348 pub model: String,
350
351 #[serde(default)]
353 pub system_prompt: Option<String>,
354
355 #[serde(default)]
357 pub tools: Vec<String>,
358
359 #[serde(default = "default_max_tool_iterations")]
361 pub max_tool_iterations: usize,
362
363 #[serde(default)]
365 pub parallel_tools: bool,
366
367 #[serde(flatten)]
369 pub extra: HashMap<String, toml::Value>,
370}
371
/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    const MAX_TOOL_ITERATIONS: usize = 10;
    MAX_TOOL_ITERATIONS
}
375
376#[derive(Debug, Clone, Serialize, Deserialize)]
380pub struct WorkflowConfig {
381 pub entry_agent: String,
383
384 pub fallback_agent: Option<String>,
386
387 #[serde(default = "default_max_depth")]
389 pub max_depth: u8,
390
391 #[serde(default = "default_max_iterations")]
393 pub max_iterations: u8,
394
395 #[serde(default)]
397 pub parallel_subagents: bool,
398}
399
/// Serde default for `WorkflowConfig::max_depth`.
fn default_max_depth() -> u8 {
    const MAX_DEPTH: u8 = 3;
    MAX_DEPTH
}
403
/// Serde default for `WorkflowConfig::max_iterations`.
fn default_max_iterations() -> u8 {
    const MAX_ITERATIONS: u8 = 5;
    MAX_ITERATIONS
}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
412pub struct RagConfig {
413 #[serde(default = "default_vector_store")]
416 pub vector_store: String,
417
418 #[serde(default = "default_vector_path")]
420 pub vector_path: String,
421
422 #[serde(default = "default_embedding_model")]
427 pub embedding_model: String,
428
429 #[serde(default)]
431 pub sparse_embeddings: bool,
432
433 #[serde(default = "default_sparse_model")]
435 pub sparse_model: String,
436
437 #[serde(default = "default_chunking_strategy")]
440 pub chunking_strategy: String,
441
442 #[serde(default = "default_chunk_size")]
444 pub chunk_size: usize,
445
446 #[serde(default = "default_chunk_overlap")]
448 pub chunk_overlap: usize,
449
450 #[serde(default = "default_min_chunk_size")]
452 pub min_chunk_size: usize,
453
454 #[serde(default = "default_search_strategy")]
457 pub search_strategy: String,
458
459 #[serde(default = "default_search_limit")]
461 pub search_limit: usize,
462
463 #[serde(default)]
465 pub search_threshold: f32,
466
467 #[serde(default)]
469 pub hybrid_weights: Option<HybridWeightsConfig>,
470
471 #[serde(default)]
474 pub rerank_enabled: bool,
475
476 #[serde(default = "default_reranker_model")]
479 pub reranker_model: String,
480
481 #[serde(default = "default_rerank_weight")]
483 pub rerank_weight: f32,
484}
485
486#[derive(Debug, Clone, Serialize, Deserialize)]
488pub struct HybridWeightsConfig {
489 #[serde(default = "default_semantic_weight")]
491 pub semantic: f32,
492 #[serde(default = "default_bm25_weight")]
494 pub bm25: f32,
495 #[serde(default = "default_fuzzy_weight")]
497 pub fuzzy: f32,
498}
499
500impl Default for HybridWeightsConfig {
501 fn default() -> Self {
502 Self {
503 semantic: 0.5,
504 bm25: 0.3,
505 fuzzy: 0.2,
506 }
507 }
508}
509
/// Serde default for `HybridWeightsConfig::semantic`.
fn default_semantic_weight() -> f32 {
    const SEMANTIC_WEIGHT: f32 = 0.5;
    SEMANTIC_WEIGHT
}
513
/// Serde default for `HybridWeightsConfig::bm25`.
fn default_bm25_weight() -> f32 {
    const BM25_WEIGHT: f32 = 0.3;
    BM25_WEIGHT
}
517
/// Serde default for `HybridWeightsConfig::fuzzy`.
fn default_fuzzy_weight() -> f32 {
    const FUZZY_WEIGHT: f32 = 0.2;
    FUZZY_WEIGHT
}
521
/// Serde default for `RagConfig::vector_store`.
fn default_vector_store() -> String {
    String::from("ares-vector")
}
525
/// Serde default for `RagConfig::vector_path`.
fn default_vector_path() -> String {
    String::from("./data/vectors")
}
529
/// Serde default for `RagConfig::embedding_model`.
fn default_embedding_model() -> String {
    String::from("bge-small-en-v1.5")
}
533
/// Serde default for `RagConfig::sparse_model`.
fn default_sparse_model() -> String {
    String::from("splade-pp-en-v1")
}
537
/// Serde default for `RagConfig::chunking_strategy`.
fn default_chunking_strategy() -> String {
    String::from("word")
}
541
/// Serde default for `RagConfig::chunk_size`.
fn default_chunk_size() -> usize {
    const CHUNK_SIZE: usize = 200;
    CHUNK_SIZE
}
545
/// Serde default for `RagConfig::chunk_overlap`.
fn default_chunk_overlap() -> usize {
    const CHUNK_OVERLAP: usize = 50;
    CHUNK_OVERLAP
}
549
/// Serde default for `RagConfig::min_chunk_size`.
fn default_min_chunk_size() -> usize {
    const MIN_CHUNK_SIZE: usize = 20;
    MIN_CHUNK_SIZE
}
553
/// Serde default for `RagConfig::search_strategy`.
fn default_search_strategy() -> String {
    String::from("semantic")
}
557
/// Serde default for `RagConfig::search_limit`.
fn default_search_limit() -> usize {
    const SEARCH_LIMIT: usize = 10;
    SEARCH_LIMIT
}
561
/// Serde default for `RagConfig::reranker_model`.
fn default_reranker_model() -> String {
    String::from("bge-reranker-base")
}
565
/// Serde default for `RagConfig::rerank_weight`.
fn default_rerank_weight() -> f32 {
    const RERANK_WEIGHT: f32 = 0.6;
    RERANK_WEIGHT
}
569
570impl Default for RagConfig {
571 fn default() -> Self {
572 Self {
573 vector_store: default_vector_store(),
574 vector_path: default_vector_path(),
575 embedding_model: default_embedding_model(),
576 sparse_embeddings: false,
577 sparse_model: default_sparse_model(),
578 chunking_strategy: default_chunking_strategy(),
579 chunk_size: default_chunk_size(),
580 chunk_overlap: default_chunk_overlap(),
581 min_chunk_size: default_min_chunk_size(),
582 search_strategy: default_search_strategy(),
583 search_limit: default_search_limit(),
584 search_threshold: 0.0,
585 hybrid_weights: None,
586 rerank_enabled: false,
587 reranker_model: default_reranker_model(),
588 rerank_weight: default_rerank_weight(),
589 }
590 }
591}
592
593#[derive(Debug, Clone, Serialize, Deserialize)]
601pub struct DynamicConfigPaths {
602 #[serde(default = "default_agents_dir")]
604 pub agents_dir: std::path::PathBuf,
605
606 #[serde(default = "default_workflows_dir")]
608 pub workflows_dir: std::path::PathBuf,
609
610 #[serde(default = "default_models_dir")]
612 pub models_dir: std::path::PathBuf,
613
614 #[serde(default = "default_tools_dir")]
616 pub tools_dir: std::path::PathBuf,
617
618 #[serde(default = "default_mcps_dir")]
620 pub mcps_dir: std::path::PathBuf,
621
622 #[serde(default = "default_hot_reload")]
624 pub hot_reload: bool,
625
626 #[serde(default = "default_watch_interval")]
628 pub watch_interval_ms: u64,
629}
630
/// Serde default for `DynamicConfigPaths::agents_dir`.
fn default_agents_dir() -> std::path::PathBuf {
    "config/agents".into()
}
634
/// Serde default for `DynamicConfigPaths::workflows_dir`.
fn default_workflows_dir() -> std::path::PathBuf {
    "config/workflows".into()
}
638
/// Serde default for `DynamicConfigPaths::models_dir`.
fn default_models_dir() -> std::path::PathBuf {
    "config/models".into()
}
642
/// Serde default for `DynamicConfigPaths::tools_dir`.
fn default_tools_dir() -> std::path::PathBuf {
    "config/tools".into()
}
646
/// Serde default for `DynamicConfigPaths::mcps_dir`.
fn default_mcps_dir() -> std::path::PathBuf {
    "config/mcps".into()
}
650
/// Serde default for `DynamicConfigPaths::hot_reload`.
fn default_hot_reload() -> bool {
    const HOT_RELOAD: bool = true;
    HOT_RELOAD
}
654
/// Serde default for `DynamicConfigPaths::watch_interval_ms`.
fn default_watch_interval() -> u64 {
    const WATCH_INTERVAL_MS: u64 = 1000;
    WATCH_INTERVAL_MS
}
658
659impl Default for DynamicConfigPaths {
660 fn default() -> Self {
661 Self {
662 agents_dir: default_agents_dir(),
663 workflows_dir: default_workflows_dir(),
664 models_dir: default_models_dir(),
665 tools_dir: default_tools_dir(),
666 mcps_dir: default_mcps_dir(),
667 hot_reload: default_hot_reload(),
668 watch_interval_ms: default_watch_interval(),
669 }
670 }
671}
672
673#[derive(Debug, Clone)]
677pub struct ConfigWarning {
678 pub kind: ConfigWarningKind,
680
681 pub message: String,
683}
684
/// Category of a [`ConfigWarning`].
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` in addition
/// to `PartialEq`; this lets callers group or de-duplicate warnings by kind
/// (e.g. in a `HashSet`/`HashMap`) without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigWarningKind {
    /// A provider is defined but no model references it.
    UnusedProvider,

    /// A model is defined but no agent references it.
    UnusedModel,

    /// A tool is defined but no agent references it.
    UnusedTool,

    /// An agent is defined but no workflow references it.
    UnusedAgent,
}
700
701impl std::fmt::Display for ConfigWarning {
702 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
703 write!(f, "{}", self.message)
704 }
705}
706
707#[derive(Debug, thiserror::Error)]
709pub enum ConfigError {
710 #[error("Configuration file not found: {0}")]
712 FileNotFound(PathBuf),
713
714 #[error("Failed to read configuration file: {0}")]
716 ReadError(#[from] std::io::Error),
717
718 #[error("Failed to parse TOML: {0}")]
720 ParseError(#[from] toml::de::Error),
721
722 #[error("Validation error: {0}")]
724 ValidationError(String),
725
726 #[error("Environment variable '{0}' referenced in config is not set")]
728 MissingEnvVar(String),
729
730 #[error("Provider '{0}' referenced by model '{1}' does not exist")]
732 MissingProvider(String, String),
733
734 #[error("Model '{0}' referenced by agent '{1}' does not exist")]
736 MissingModel(String, String),
737
738 #[error("Agent '{0}' referenced by workflow '{1}' does not exist")]
740 MissingAgent(String, String),
741
742 #[error("Tool '{0}' referenced by agent '{1}' does not exist")]
744 MissingTool(String, String),
745
746 #[error("Circular reference detected: {0}")]
748 CircularReference(String),
749
750 #[error("Watch error: {0}")]
752 WatchError(#[from] notify::Error),
753}
754
755impl AresConfig {
756 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
763 let path = path.as_ref();
764
765 if !path.exists() {
766 return Err(ConfigError::FileNotFound(path.to_path_buf()));
767 }
768
769 let content = fs::read_to_string(path)?;
770 let config: AresConfig = toml::from_str(&content)?;
771
772 config.validate()?;
774
775 Ok(config)
776 }
777
778 pub fn load_unchecked<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
784 let path = path.as_ref();
785
786 if !path.exists() {
787 return Err(ConfigError::FileNotFound(path.to_path_buf()));
788 }
789
790 let content = fs::read_to_string(path)?;
791 let config: AresConfig = toml::from_str(&content)?;
792
793 Ok(config)
794 }
795
796 pub fn validate(&self) -> Result<(), ConfigError> {
798 self.validate_env_var(&self.auth.jwt_secret_env)?;
800 self.validate_env_var(&self.auth.api_key_env)?;
801
802 if let Some(ref env) = self.database.turso_url_env {
804 self.validate_env_var(env)?;
805 }
806 if let Some(ref env) = self.database.turso_token_env {
807 self.validate_env_var(env)?;
808 }
809 if let Some(ref qdrant) = self.database.qdrant {
810 if let Some(ref env) = qdrant.api_key_env {
811 self.validate_env_var(env)?;
812 }
813 }
814
815 for (name, provider) in &self.providers {
817 match provider {
818 ProviderConfig::OpenAI { api_key_env, .. } => {
819 self.validate_env_var(api_key_env)?;
820 }
821 ProviderConfig::LlamaCpp { model_path, .. } => {
822 if !Path::new(model_path).exists() {
824 return Err(ConfigError::ValidationError(format!(
825 "LlamaCpp model path does not exist: {} (provider: {})",
826 model_path, name
827 )));
828 }
829 }
830 ProviderConfig::Ollama { .. } => {
831 }
833 }
834 }
835
836 for (model_name, model_config) in &self.models {
838 if !self.providers.contains_key(&model_config.provider) {
839 return Err(ConfigError::MissingProvider(
840 model_config.provider.clone(),
841 model_name.clone(),
842 ));
843 }
844 }
845
846 for (agent_name, agent_config) in &self.agents {
848 if !self.models.contains_key(&agent_config.model) {
849 return Err(ConfigError::MissingModel(
850 agent_config.model.clone(),
851 agent_name.clone(),
852 ));
853 }
854
855 for tool_name in &agent_config.tools {
856 if !self.tools.contains_key(tool_name) {
857 return Err(ConfigError::MissingTool(
858 tool_name.clone(),
859 agent_name.clone(),
860 ));
861 }
862 }
863 }
864
865 for (workflow_name, workflow_config) in &self.workflows {
867 if !self.agents.contains_key(&workflow_config.entry_agent) {
868 return Err(ConfigError::MissingAgent(
869 workflow_config.entry_agent.clone(),
870 workflow_name.clone(),
871 ));
872 }
873
874 if let Some(ref fallback) = workflow_config.fallback_agent {
875 if !self.agents.contains_key(fallback) {
876 return Err(ConfigError::MissingAgent(
877 fallback.clone(),
878 workflow_name.clone(),
879 ));
880 }
881 }
882 }
883
884 self.detect_circular_references()?;
886
887 Ok(())
888 }
889
890 fn detect_circular_references(&self) -> Result<(), ConfigError> {
895 use std::collections::HashSet;
896
897 for (workflow_name, workflow_config) in &self.workflows {
898 let mut visited = HashSet::new();
899 let mut current = Some(workflow_config.entry_agent.as_str());
900
901 while let Some(agent_name) = current {
902 if visited.contains(agent_name) {
903 return Err(ConfigError::CircularReference(format!(
904 "Circular reference detected in workflow '{}': agent '{}' appears multiple times in the chain",
905 workflow_name, agent_name
906 )));
907 }
908 visited.insert(agent_name);
909
910 current = None;
913
914 if let Some(ref fallback) = workflow_config.fallback_agent {
916 if fallback == &workflow_config.entry_agent {
917 return Err(ConfigError::CircularReference(format!(
918 "Workflow '{}' has entry_agent '{}' that equals fallback_agent",
919 workflow_name, workflow_config.entry_agent
920 )));
921 }
922 }
923 }
924 }
925
926 Ok(())
927 }
928
929 pub fn validate_with_warnings(&self) -> Result<Vec<ConfigWarning>, ConfigError> {
933 self.validate()?;
935
936 let mut warnings = Vec::new();
938
939 warnings.extend(self.check_unused_providers());
941
942 warnings.extend(self.check_unused_models());
944
945 warnings.extend(self.check_unused_tools());
947
948 warnings.extend(self.check_unused_agents());
950
951 Ok(warnings)
952 }
953
954 fn check_unused_providers(&self) -> Vec<ConfigWarning> {
956 use std::collections::HashSet;
957
958 let referenced: HashSet<_> = self.models.values().map(|m| m.provider.as_str()).collect();
959
960 self.providers
961 .keys()
962 .filter(|name| !referenced.contains(name.as_str()))
963 .map(|name| ConfigWarning {
964 kind: ConfigWarningKind::UnusedProvider,
965 message: format!(
966 "Provider '{}' is defined but not referenced by any model",
967 name
968 ),
969 })
970 .collect()
971 }
972
973 fn check_unused_models(&self) -> Vec<ConfigWarning> {
975 use std::collections::HashSet;
976
977 let referenced: HashSet<_> = self.agents.values().map(|a| a.model.as_str()).collect();
978
979 self.models
980 .keys()
981 .filter(|name| !referenced.contains(name.as_str()))
982 .map(|name| ConfigWarning {
983 kind: ConfigWarningKind::UnusedModel,
984 message: format!(
985 "Model '{}' is defined but not referenced by any agent",
986 name
987 ),
988 })
989 .collect()
990 }
991
992 fn check_unused_tools(&self) -> Vec<ConfigWarning> {
994 use std::collections::HashSet;
995
996 let referenced: HashSet<_> = self
997 .agents
998 .values()
999 .flat_map(|a| a.tools.iter().map(|t| t.as_str()))
1000 .collect();
1001
1002 self.tools
1003 .keys()
1004 .filter(|name| !referenced.contains(name.as_str()))
1005 .map(|name| ConfigWarning {
1006 kind: ConfigWarningKind::UnusedTool,
1007 message: format!("Tool '{}' is defined but not referenced by any agent", name),
1008 })
1009 .collect()
1010 }
1011
1012 fn check_unused_agents(&self) -> Vec<ConfigWarning> {
1014 use std::collections::HashSet;
1015
1016 let referenced: HashSet<_> = self
1017 .workflows
1018 .values()
1019 .flat_map(|w| {
1020 let mut refs = vec![w.entry_agent.as_str()];
1021 if let Some(ref fallback) = w.fallback_agent {
1022 refs.push(fallback.as_str());
1023 }
1024 refs
1025 })
1026 .collect();
1027
1028 let system_agents: HashSet<&str> = ["orchestrator", "router"].into_iter().collect();
1030
1031 self.agents
1032 .keys()
1033 .filter(|name| {
1034 !referenced.contains(name.as_str()) && !system_agents.contains(name.as_str())
1035 })
1036 .map(|name| ConfigWarning {
1037 kind: ConfigWarningKind::UnusedAgent,
1038 message: format!(
1039 "Agent '{}' is defined but not referenced by any workflow",
1040 name
1041 ),
1042 })
1043 .collect()
1044 }
1045
1046 fn validate_env_var(&self, name: &str) -> Result<(), ConfigError> {
1047 std::env::var(name).map_err(|_| ConfigError::MissingEnvVar(name.to_string()))?;
1048 Ok(())
1049 }
1050
1051 pub fn resolve_env(&self, env_name: &str) -> Option<String> {
1053 std::env::var(env_name).ok()
1054 }
1055
1056 pub fn jwt_secret(&self) -> Result<String, ConfigError> {
1058 self.resolve_env(&self.auth.jwt_secret_env)
1059 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.jwt_secret_env.clone()))
1060 }
1061
1062 pub fn api_key(&self) -> Result<String, ConfigError> {
1064 self.resolve_env(&self.auth.api_key_env)
1065 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.api_key_env.clone()))
1066 }
1067
1068 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
1070 self.providers.get(name)
1071 }
1072
1073 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
1075 self.models.get(name)
1076 }
1077
1078 pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
1080 self.agents.get(name)
1081 }
1082
1083 pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
1085 self.tools.get(name)
1086 }
1087
1088 pub fn get_workflow(&self, name: &str) -> Option<&WorkflowConfig> {
1090 self.workflows.get(name)
1091 }
1092
1093 pub fn enabled_tools(&self) -> Vec<&str> {
1095 self.tools
1096 .iter()
1097 .filter(|(_, config)| config.enabled)
1098 .map(|(name, _)| name.as_str())
1099 .collect()
1100 }
1101
1102 pub fn agent_tools(&self, agent_name: &str) -> Vec<&str> {
1104 self.get_agent(agent_name)
1105 .map(|agent| {
1106 agent
1107 .tools
1108 .iter()
1109 .filter(|t| self.get_tool(t).map(|tc| tc.enabled).unwrap_or(false))
1110 .map(|s| s.as_str())
1111 .collect()
1112 })
1113 .unwrap_or_default()
1114 }
1115}
1116
1117pub struct AresConfigManager {
1121 config: Arc<ArcSwap<AresConfig>>,
1122 config_path: PathBuf,
1123 watcher: RwLock<Option<RecommendedWatcher>>,
1124 reload_tx: Option<mpsc::UnboundedSender<()>>,
1125}
1126
1127impl AresConfigManager {
1128 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
1134 let path = path.as_ref();
1136 let path = if path.is_absolute() {
1137 path.to_path_buf()
1138 } else {
1139 std::env::current_dir()
1140 .map_err(ConfigError::ReadError)?
1141 .join(path)
1142 };
1143
1144 let config = AresConfig::load(&path)?;
1145
1146 Ok(Self {
1147 config: Arc::new(ArcSwap::from_pointee(config)),
1148 config_path: path,
1149 watcher: RwLock::new(None),
1150 reload_tx: None,
1151 })
1152 }
1153
1154 pub fn config(&self) -> Arc<AresConfig> {
1156 self.config.load_full()
1157 }
1158
1159 pub fn reload(&self) -> Result<(), ConfigError> {
1161 info!("Reloading configuration from {:?}", self.config_path);
1162
1163 let new_config = AresConfig::load(&self.config_path)?;
1164 self.config.store(Arc::new(new_config));
1165
1166 info!("Configuration reloaded successfully");
1167 Ok(())
1168 }
1169
1170 pub fn start_watching(&mut self) -> Result<(), ConfigError> {
1172 let (tx, mut rx) = mpsc::unbounded_channel::<()>();
1173 self.reload_tx = Some(tx.clone());
1174
1175 let config_path = self.config_path.clone();
1176 let config_arc = Arc::clone(&self.config);
1177
1178 let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
1180 match res {
1181 Ok(event) => {
1182 if event.kind.is_modify() || event.kind.is_create() {
1183 let _ = tx.send(());
1185 }
1186 }
1187 Err(e) => {
1188 error!("Config watcher error: {:?}", e);
1189 }
1190 }
1191 })?;
1192
1193 if let Some(parent) = self.config_path.parent() {
1195 watcher.watch(parent, RecursiveMode::NonRecursive)?;
1196 }
1197
1198 *self.watcher.write() = Some(watcher);
1199
1200 let config_path_clone = config_path.clone();
1202 tokio::spawn(async move {
1203 let mut last_reload = std::time::Instant::now();
1204 let debounce_duration = Duration::from_millis(500);
1205
1206 while rx.recv().await.is_some() {
1207 if last_reload.elapsed() < debounce_duration {
1209 continue;
1210 }
1211
1212 tokio::time::sleep(Duration::from_millis(100)).await;
1214
1215 match AresConfig::load(&config_path_clone) {
1216 Ok(new_config) => {
1217 config_arc.store(Arc::new(new_config));
1218 info!("Configuration hot-reloaded successfully");
1219 last_reload = std::time::Instant::now();
1220 }
1221 Err(e) => {
1222 warn!(
1223 "Failed to hot-reload config: {}. Keeping previous config.",
1224 e
1225 );
1226 }
1227 }
1228 }
1229 });
1230
1231 info!("Configuration hot-reload watcher started");
1232 Ok(())
1233 }
1234
1235 pub fn stop_watching(&self) {
1237 *self.watcher.write() = None;
1238 info!("Configuration hot-reload watcher stopped");
1239 }
1240}
1241
1242impl Clone for AresConfigManager {
1243 fn clone(&self) -> Self {
1244 Self {
1245 config: Arc::clone(&self.config),
1246 config_path: self.config_path.clone(),
1247 watcher: RwLock::new(None), reload_tx: self.reload_tx.clone(),
1249 }
1250 }
1251}
1252
1253impl AresConfigManager {
1254 pub fn from_config(config: AresConfig) -> Self {
1257 Self {
1258 config: Arc::new(ArcSwap::from_pointee(config)),
1259 config_path: PathBuf::from("test-config.toml"),
1260 watcher: RwLock::new(None),
1261 reload_tx: None,
1262 }
1263 }
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268 use super::*;
1269
1270 fn create_test_config() -> String {
1271 r#"
1272[server]
1273host = "127.0.0.1"
1274port = 3000
1275log_level = "debug"
1276
1277[auth]
1278jwt_secret_env = "TEST_JWT_SECRET"
1279jwt_access_expiry = 900
1280jwt_refresh_expiry = 604800
1281api_key_env = "TEST_API_KEY"
1282
1283[database]
1284url = "./data/test.db"
1285
1286[providers.ollama-local]
1287type = "ollama"
1288base_url = "http://localhost:11434"
1289default_model = "ministral-3:3b"
1290
1291[models.default]
1292provider = "ollama-local"
1293model = "ministral-3:3b"
1294temperature = 0.7
1295max_tokens = 512
1296
1297[tools.calculator]
1298enabled = true
1299description = "Basic calculator"
1300timeout_secs = 10
1301
1302[agents.router]
1303model = "default"
1304tools = []
1305max_tool_iterations = 5
1306
1307[workflows.default]
1308entry_agent = "router"
1309max_depth = 3
1310max_iterations = 5
1311"#
1312 .to_string()
1313 }
1314
1315 #[test]
1316 fn test_parse_config() {
1317 unsafe {
1320 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1321 std::env::set_var("TEST_API_KEY", "test-api-key");
1322 }
1323
1324 let content = create_test_config();
1325 let config: AresConfig = toml::from_str(&content).expect("Failed to parse config");
1326
1327 assert_eq!(config.server.host, "127.0.0.1");
1328 assert_eq!(config.server.port, 3000);
1329 assert!(config.providers.contains_key("ollama-local"));
1330 assert!(config.models.contains_key("default"));
1331 assert!(config.agents.contains_key("router"));
1332 }
1333
1334 #[test]
1335 fn test_validation_missing_provider() {
1336 unsafe {
1338 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1339 std::env::set_var("TEST_API_KEY", "test-key");
1340 }
1341
1342 let content = r#"
1343[server]
1344[auth]
1345jwt_secret_env = "TEST_JWT_SECRET"
1346api_key_env = "TEST_API_KEY"
1347[database]
1348[models.test]
1349provider = "nonexistent"
1350model = "test"
1351"#;
1352
1353 let config: AresConfig = toml::from_str(content).unwrap();
1354 let result = config.validate();
1355
1356 assert!(matches!(result, Err(ConfigError::MissingProvider(_, _))));
1357 }
1358
1359 #[test]
1360 fn test_validation_missing_model() {
1361 unsafe {
1363 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1364 std::env::set_var("TEST_API_KEY", "test-key");
1365 }
1366
1367 let content = r#"
1368[server]
1369[auth]
1370jwt_secret_env = "TEST_JWT_SECRET"
1371api_key_env = "TEST_API_KEY"
1372[database]
1373[providers.test]
1374type = "ollama"
1375default_model = "ministral-3:3b"
1376[agents.test]
1377model = "nonexistent"
1378"#;
1379
1380 let config: AresConfig = toml::from_str(content).unwrap();
1381 let result = config.validate();
1382
1383 assert!(matches!(result, Err(ConfigError::MissingModel(_, _))));
1384 }
1385
1386 #[test]
1387 fn test_validation_missing_tool() {
1388 unsafe {
1390 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1391 std::env::set_var("TEST_API_KEY", "test-key");
1392 }
1393
1394 let content = r#"
1395[server]
1396[auth]
1397jwt_secret_env = "TEST_JWT_SECRET"
1398api_key_env = "TEST_API_KEY"
1399[database]
1400[providers.test]
1401type = "ollama"
1402default_model = "ministral-3:3b"
1403[models.default]
1404provider = "test"
1405model = "ministral-3:3b"
1406[agents.test]
1407model = "default"
1408tools = ["nonexistent_tool"]
1409"#;
1410
1411 let config: AresConfig = toml::from_str(content).unwrap();
1412 let result = config.validate();
1413
1414 assert!(matches!(result, Err(ConfigError::MissingTool(_, _))));
1415 }
1416
1417 #[test]
1418 fn test_validation_missing_workflow_agent() {
1419 unsafe {
1421 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1422 std::env::set_var("TEST_API_KEY", "test-key");
1423 }
1424
1425 let content = r#"
1426[server]
1427[auth]
1428jwt_secret_env = "TEST_JWT_SECRET"
1429api_key_env = "TEST_API_KEY"
1430[database]
1431[workflows.test]
1432entry_agent = "nonexistent_agent"
1433"#;
1434
1435 let config: AresConfig = toml::from_str(content).unwrap();
1436 let result = config.validate();
1437
1438 assert!(matches!(result, Err(ConfigError::MissingAgent(_, _))));
1439 }
1440
1441 #[test]
1442 fn test_get_provider() {
1443 let content = create_test_config();
1444 let config: AresConfig = toml::from_str(&content).unwrap();
1445
1446 assert!(config.get_provider("ollama-local").is_some());
1447 assert!(config.get_provider("nonexistent").is_none());
1448 }
1449
1450 #[test]
1451 fn test_get_model() {
1452 let content = create_test_config();
1453 let config: AresConfig = toml::from_str(&content).unwrap();
1454
1455 assert!(config.get_model("default").is_some());
1456 assert!(config.get_model("nonexistent").is_none());
1457 }
1458
1459 #[test]
1460 fn test_get_agent() {
1461 let content = create_test_config();
1462 let config: AresConfig = toml::from_str(&content).unwrap();
1463
1464 assert!(config.get_agent("router").is_some());
1465 assert!(config.get_agent("nonexistent").is_none());
1466 }
1467
1468 #[test]
1469 fn test_get_tool() {
1470 let content = create_test_config();
1471 let config: AresConfig = toml::from_str(&content).unwrap();
1472
1473 assert!(config.get_tool("calculator").is_some());
1474 assert!(config.get_tool("nonexistent").is_none());
1475 }
1476
1477 #[test]
1478 fn test_enabled_tools() {
1479 let content = r#"
1480[server]
1481[auth]
1482jwt_secret_env = "TEST_JWT_SECRET"
1483api_key_env = "TEST_API_KEY"
1484[database]
1485[tools.enabled_tool]
1486enabled = true
1487[tools.disabled_tool]
1488enabled = false
1489"#;
1490
1491 let config: AresConfig = toml::from_str(content).unwrap();
1492 let enabled = config.enabled_tools();
1493
1494 assert!(enabled.contains(&"enabled_tool"));
1495 assert!(!enabled.contains(&"disabled_tool"));
1496 }
1497
1498 #[test]
1499 fn test_defaults() {
1500 let content = r#"
1501[server]
1502[auth]
1503jwt_secret_env = "TEST_JWT_SECRET"
1504api_key_env = "TEST_API_KEY"
1505[database]
1506"#;
1507
1508 let config: AresConfig = toml::from_str(content).unwrap();
1509
1510 assert_eq!(config.server.host, "127.0.0.1");
1512 assert_eq!(config.server.port, 3000);
1513 assert_eq!(config.server.log_level, "info");
1514
1515 assert_eq!(config.auth.jwt_access_expiry, 900);
1517 assert_eq!(config.auth.jwt_refresh_expiry, 604800);
1518
1519 assert_eq!(config.database.url, "./data/ares.db");
1521
1522 assert_eq!(config.rag.embedding_model, "bge-small-en-v1.5");
1524 assert_eq!(config.rag.chunk_size, 200);
1525 assert_eq!(config.rag.chunk_overlap, 50);
1526 assert_eq!(config.rag.vector_store, "ares-vector");
1527 assert_eq!(config.rag.search_strategy, "semantic");
1528 }
1529
1530 #[test]
1531 fn test_config_manager_from_config() {
1532 let content = create_test_config();
1533 let config: AresConfig = toml::from_str(&content).unwrap();
1534
1535 let manager = AresConfigManager::from_config(config.clone());
1536 let loaded = manager.config();
1537
1538 assert_eq!(loaded.server.host, config.server.host);
1539 assert_eq!(loaded.server.port, config.server.port);
1540 }
1541
1542 #[test]
1543 fn test_circular_reference_detection() {
1544 unsafe {
1546 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1547 std::env::set_var("TEST_API_KEY", "test-key");
1548 }
1549
1550 let content = r#"
1551[server]
1552[auth]
1553jwt_secret_env = "TEST_JWT_SECRET"
1554api_key_env = "TEST_API_KEY"
1555[database]
1556[providers.test]
1557type = "ollama"
1558default_model = "ministral-3:3b"
1559[models.default]
1560provider = "test"
1561model = "ministral-3:3b"
1562[agents.agent_a]
1563model = "default"
1564[workflows.circular]
1565entry_agent = "agent_a"
1566fallback_agent = "agent_a"
1567"#;
1568
1569 let config: AresConfig = toml::from_str(content).unwrap();
1570 let result = config.validate();
1571
1572 assert!(matches!(result, Err(ConfigError::CircularReference(_))));
1573 }
1574
1575 #[test]
1576 fn test_unused_provider_warning() {
1577 unsafe {
1579 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1580 std::env::set_var("TEST_API_KEY", "test-key");
1581 }
1582
1583 let content = r#"
1584[server]
1585[auth]
1586jwt_secret_env = "TEST_JWT_SECRET"
1587api_key_env = "TEST_API_KEY"
1588[database]
1589[providers.used]
1590type = "ollama"
1591default_model = "ministral-3:3b"
1592[providers.unused]
1593type = "ollama"
1594default_model = "ministral-3:3b"
1595[models.default]
1596provider = "used"
1597model = "ministral-3:3b"
1598[agents.router]
1599model = "default"
1600"#;
1601
1602 let config: AresConfig = toml::from_str(content).unwrap();
1603 let warnings = config.validate_with_warnings().unwrap();
1604
1605 assert!(warnings
1606 .iter()
1607 .any(|w| w.kind == ConfigWarningKind::UnusedProvider && w.message.contains("unused")));
1608 }
1609
1610 #[test]
1611 fn test_unused_model_warning() {
1612 unsafe {
1614 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1615 std::env::set_var("TEST_API_KEY", "test-key");
1616 }
1617
1618 let content = r#"
1619[server]
1620[auth]
1621jwt_secret_env = "TEST_JWT_SECRET"
1622api_key_env = "TEST_API_KEY"
1623[database]
1624[providers.test]
1625type = "ollama"
1626default_model = "ministral-3:3b"
1627[models.used]
1628provider = "test"
1629model = "ministral-3:3b"
1630[models.unused]
1631provider = "test"
1632model = "other"
1633[agents.router]
1634model = "used"
1635"#;
1636
1637 let config: AresConfig = toml::from_str(content).unwrap();
1638 let warnings = config.validate_with_warnings().unwrap();
1639
1640 assert!(warnings
1641 .iter()
1642 .any(|w| w.kind == ConfigWarningKind::UnusedModel && w.message.contains("unused")));
1643 }
1644
1645 #[test]
1646 fn test_unused_tool_warning() {
1647 unsafe {
1649 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1650 std::env::set_var("TEST_API_KEY", "test-key");
1651 }
1652
1653 let content = r#"
1654[server]
1655[auth]
1656jwt_secret_env = "TEST_JWT_SECRET"
1657api_key_env = "TEST_API_KEY"
1658[database]
1659[providers.test]
1660type = "ollama"
1661default_model = "ministral-3:3b"
1662[models.default]
1663provider = "test"
1664model = "ministral-3:3b"
1665[tools.used_tool]
1666enabled = true
1667[tools.unused_tool]
1668enabled = true
1669[agents.router]
1670model = "default"
1671tools = ["used_tool"]
1672"#;
1673
1674 let config: AresConfig = toml::from_str(content).unwrap();
1675 let warnings = config.validate_with_warnings().unwrap();
1676
1677 assert!(warnings
1678 .iter()
1679 .any(|w| w.kind == ConfigWarningKind::UnusedTool && w.message.contains("unused_tool")));
1680 }
1681
1682 #[test]
1683 fn test_unused_agent_warning() {
1684 unsafe {
1686 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1687 std::env::set_var("TEST_API_KEY", "test-key");
1688 }
1689
1690 let content = r#"
1691[server]
1692[auth]
1693jwt_secret_env = "TEST_JWT_SECRET"
1694api_key_env = "TEST_API_KEY"
1695[database]
1696[providers.test]
1697type = "ollama"
1698default_model = "ministral-3:3b"
1699[models.default]
1700provider = "test"
1701model = "ministral-3:3b"
1702[agents.router]
1703model = "default"
1704[agents.orphaned]
1705model = "default"
1706[workflows.test_flow]
1707entry_agent = "router"
1708"#;
1709
1710 let config: AresConfig = toml::from_str(content).unwrap();
1711 let warnings = config.validate_with_warnings().unwrap();
1712
1713 assert!(warnings
1714 .iter()
1715 .any(|w| w.kind == ConfigWarningKind::UnusedAgent && w.message.contains("orphaned")));
1716 }
1717
1718 #[test]
1719 fn test_no_warnings_for_fully_connected_config() {
1720 unsafe {
1722 std::env::set_var("TEST_JWT_SECRET", "test-secret");
1723 std::env::set_var("TEST_API_KEY", "test-key");
1724 }
1725
1726 let content = r#"
1727[server]
1728[auth]
1729jwt_secret_env = "TEST_JWT_SECRET"
1730api_key_env = "TEST_API_KEY"
1731[database]
1732[providers.test]
1733type = "ollama"
1734default_model = "ministral-3:3b"
1735[models.default]
1736provider = "test"
1737model = "ministral-3:3b"
1738[tools.calc]
1739enabled = true
1740[agents.router]
1741model = "default"
1742tools = ["calc"]
1743[workflows.main]
1744entry_agent = "router"
1745"#;
1746
1747 let config: AresConfig = toml::from_str(content).unwrap();
1748 let warnings = config.validate_with_warnings().unwrap();
1749
1750 assert!(
1751 warnings.is_empty(),
1752 "Expected no warnings but got: {:?}",
1753 warnings
1754 );
1755 }
1756}