1use arc_swap::ArcSwap;
12use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::fs;
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::mpsc;
21use tracing::{error, info, warn};
22
/// Root of the TOML configuration file.
///
/// The `server`, `auth` and `database` tables carry no serde fallback and must be present
/// in the file (their individual fields may still have defaults). Every other section falls
/// back to its `Default` value when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AresConfig {
    /// `[server]` table: host, port, log level, CORS origins and rate limits.
    pub server: ServerConfig,

    /// `[auth]` table: names of the env vars holding secrets, and token lifetimes.
    pub auth: AuthConfig,

    /// `[database]` table.
    pub database: DatabaseConfig,

    /// `[providers.<name>]` tables, keyed by provider name.
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,

    /// `[models.<name>]` tables; `validate` requires each model's provider to exist.
    #[serde(default)]
    pub models: HashMap<String, ModelConfig>,

    /// `[tools.<name>]` tables, keyed by tool name.
    #[serde(default)]
    pub tools: HashMap<String, ToolConfig>,

    /// `[agents.<name>]` tables; `validate` requires each agent's model and tools to exist.
    #[serde(default)]
    pub agents: HashMap<String, AgentConfig>,

    /// `[workflows.<name>]` tables; `validate` requires entry/fallback agents to exist.
    #[serde(default)]
    pub workflows: HashMap<String, WorkflowConfig>,

    /// `[rag]` table.
    #[serde(default)]
    pub rag: RagConfig,

    /// `[config]` table: per-kind config directories and hot-reload settings.
    #[serde(default)]
    pub config: DynamicConfigPaths,
}
67
/// `[server]` section. Every field is optional in the file and falls back to the default
/// noted on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host/interface address; defaults to `127.0.0.1`.
    #[serde(default = "default_host")]
    pub host: String,

    /// Port number; defaults to `3000`.
    #[serde(default = "default_port")]
    pub port: u16,

    /// Log level name; defaults to `"info"`.
    #[serde(default = "default_log_level")]
    pub log_level: String,

    /// Allowed CORS origins; defaults to `["*"]`.
    // NOTE(review): if the consumer treats `*` as a wildcard, the default allows any
    // origin — confirm that is intended outside development.
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,

    /// Rate limit; defaults to `100`. Presumably requests per second, per the field name —
    /// confirm at the consumer.
    #[serde(default = "default_rate_limit")]
    pub rate_limit_per_second: u32,

    /// Rate-limit burst allowance; defaults to `10`.
    #[serde(default = "default_rate_limit_burst")]
    pub rate_limit_burst: u32,
}
98
/// Serde default for `server.host`: the IPv4 loopback address.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Serde default for `server.port`.
fn default_port() -> u16 {
    3_000
}

/// Serde default for `server.log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `server.cors_origins`: a single `*` entry.
fn default_cors_origins() -> Vec<String> {
    vec![String::from("*")]
}

/// Serde default for `server.rate_limit_per_second`.
fn default_rate_limit() -> u32 {
    100_u32
}

/// Serde default for `server.rate_limit_burst`.
fn default_rate_limit_burst() -> u32 {
    10_u32
}
122
123impl Default for ServerConfig {
124 fn default() -> Self {
125 Self {
126 host: default_host(),
127 port: default_port(),
128 log_level: default_log_level(),
129 cors_origins: default_cors_origins(),
130 rate_limit_per_second: default_rate_limit(),
131 rate_limit_burst: default_rate_limit_burst(),
132 }
133 }
134}
135
/// `[auth]` section. Secrets are never stored in the file; only the *names* of the
/// environment variables that hold them are.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Name of the env var holding the JWT secret (required key; see `AresConfig::jwt_secret`).
    pub jwt_secret_env: String,

    /// Access-token lifetime; defaults to `900`. Presumably seconds (15 min) — confirm at
    /// the consumer.
    #[serde(default = "default_jwt_access_expiry")]
    pub jwt_access_expiry: i64,

    /// Refresh-token lifetime; defaults to `604800`. Presumably seconds (7 days) — confirm.
    #[serde(default = "default_jwt_refresh_expiry")]
    pub jwt_refresh_expiry: i64,

    /// Name of the env var holding the API key (required key; see `AresConfig::api_key`).
    pub api_key_env: String,
}
155
/// Serde default for `auth.jwt_access_expiry`: 900, i.e. 15 minutes if the unit is seconds.
fn default_jwt_access_expiry() -> i64 {
    15 * 60
}

/// Serde default for `auth.jwt_refresh_expiry`: 604800, i.e. 7 days if the unit is seconds.
fn default_jwt_refresh_expiry() -> i64 {
    7 * 24 * 60 * 60
}
163
164impl Default for AuthConfig {
165 fn default() -> Self {
166 Self {
167 jwt_secret_env: "JWT_SECRET".to_string(),
168 jwt_access_expiry: default_jwt_access_expiry(),
169 jwt_refresh_expiry: default_jwt_refresh_expiry(),
170 api_key_env: "API_KEY".to_string(),
171 }
172 }
173}
174
/// `[database]` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Database location; defaults to `./data/ares.db`.
    #[serde(default = "default_database_url")]
    pub url: String,

    /// Optional name of an env var holding a Turso URL; `validate` requires it to be set if given.
    pub turso_url_env: Option<String>,

    /// Optional name of an env var holding a Turso token; `validate` requires it to be set if given.
    pub turso_token_env: Option<String>,

    /// Optional `[database.qdrant]` sub-table.
    pub qdrant: Option<QdrantConfig>,
}
193
/// Serde default for `database.url`: a file path relative to the working directory.
fn default_database_url() -> String {
    String::from("./data/ares.db")
}
197
198impl Default for DatabaseConfig {
199 fn default() -> Self {
200 Self {
201 url: default_database_url(),
202 turso_url_env: None,
203 turso_token_env: None,
204 qdrant: None,
205 }
206 }
207}
208
/// `[database.qdrant]` sub-table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConfig {
    /// Qdrant endpoint; defaults to `http://localhost:6334`.
    #[serde(default = "default_qdrant_url")]
    pub url: String,

    /// Optional name of an env var holding the API key; `validate` requires it to be set if given.
    pub api_key_env: Option<String>,
}
219
/// Serde default for `database.qdrant.url`.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
223
224impl Default for QdrantConfig {
225 fn default() -> Self {
226 Self {
227 url: default_qdrant_url(),
228 api_key_env: None,
229 }
230 }
231}
232
/// An LLM backend, selected by the `type` key of a `[providers.<name>]` table
/// (`"ollama"`, `"openai"` or `"llamacpp"`, the lowercased variant names).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// An Ollama server.
    Ollama {
        /// Server URL; defaults to `http://localhost:11434`.
        #[serde(default = "default_ollama_url")]
        base_url: String,
        /// Default model name for this provider.
        default_model: String,
    },
    /// The OpenAI API (base URL overridable through `api_base`).
    OpenAI {
        /// Name of the env var holding the API key; `validate` requires it to be set.
        api_key_env: String,
        /// API base URL; defaults to `https://api.openai.com/v1`.
        #[serde(default = "default_openai_base")]
        api_base: String,
        /// Default model name for this provider.
        default_model: String,
    },
    /// A llama.cpp model identified by a local file path.
    LlamaCpp {
        /// Path to the model file; `validate` fails if it does not exist.
        model_path: String,
        /// Context size; defaults to `4096`.
        #[serde(default = "default_n_ctx")]
        n_ctx: u32,
        /// Thread count; defaults to `4`.
        #[serde(default = "default_n_threads")]
        n_threads: u32,
        /// Token limit; defaults to `512`.
        #[serde(default = "default_max_tokens")]
        max_tokens: u32,
    },
}
272
/// Serde default for an Ollama provider's `base_url`.
fn default_ollama_url() -> String {
    String::from("http://localhost:11434")
}

/// Serde default for an OpenAI provider's `api_base`.
fn default_openai_base() -> String {
    String::from("https://api.openai.com/v1")
}

/// Serde default for a LlamaCpp provider's `n_ctx` (4096).
fn default_n_ctx() -> u32 {
    4 * 1024
}

/// Serde default for a LlamaCpp provider's `n_threads`.
fn default_n_threads() -> u32 {
    4_u32
}

/// Serde default for a LlamaCpp provider's `max_tokens`.
fn default_max_tokens() -> u32 {
    512_u32
}
292
/// A `[models.<name>]` table: a named model bound to a provider, with sampling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Key into `AresConfig::providers`; `validate` fails with `MissingProvider` if absent.
    pub provider: String,

    /// Model identifier passed to the provider.
    pub model: String,

    /// Sampling temperature; defaults to `0.7`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,

    /// Token limit; defaults to `512`.
    #[serde(default = "default_model_max_tokens")]
    pub max_tokens: u32,

    /// Optional nucleus-sampling parameter; `None` when not set in the file.
    pub top_p: Option<f32>,

    /// Optional frequency penalty; `None` when not set in the file.
    pub frequency_penalty: Option<f32>,

    /// Optional presence penalty; `None` when not set in the file.
    pub presence_penalty: Option<f32>,
}
321
/// Serde default for `ModelConfig::temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde default for `ModelConfig::max_tokens`.
fn default_model_max_tokens() -> u32 {
    512_u32
}
329
/// A `[tools.<name>]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// Whether the tool is active; defaults to `true`. Disabled tools are excluded by
    /// `AresConfig::enabled_tools` and `AresConfig::agent_tools`.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,

    /// Timeout in seconds (per the field name); defaults to `30`.
    #[serde(default = "default_tool_timeout")]
    pub timeout_secs: u64,

    /// Any other keys in the tool's table, kept verbatim for tool-specific settings.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}
351
/// Serde helper for boolean fields that default to on.
fn default_true() -> bool {
    !false
}

/// Serde default for `ToolConfig::timeout_secs`.
fn default_tool_timeout() -> u64 {
    30_u64
}
359
360impl Default for ToolConfig {
361 fn default() -> Self {
362 Self {
363 enabled: true,
364 description: None,
365 timeout_secs: default_tool_timeout(),
366 extra: HashMap::new(),
367 }
368 }
369}
370
/// An `[agents.<name>]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Key into `AresConfig::models`; `validate` fails with `MissingModel` if absent.
    pub model: String,

    /// Optional system prompt.
    #[serde(default)]
    pub system_prompt: Option<String>,

    /// Keys into `AresConfig::tools`; `validate` fails with `MissingTool` for unknown names.
    #[serde(default)]
    pub tools: Vec<String>,

    /// Iteration cap; defaults to `10`. Semantics are defined by the consumer (not visible here).
    #[serde(default = "default_max_tool_iterations")]
    pub max_tool_iterations: usize,

    /// Defaults to `false`. Presumably allows tool calls to run concurrently — confirm at the consumer.
    #[serde(default)]
    pub parallel_tools: bool,

    /// Any other keys in the agent's table, kept verbatim.
    #[serde(flatten)]
    pub extra: HashMap<String, toml::Value>,
}
399
/// Serde default for `AgentConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    10_usize
}
403
/// A `[workflows.<name>]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowConfig {
    /// Key into `AresConfig::agents`; `validate` fails with `MissingAgent` if absent.
    pub entry_agent: String,

    /// Optional key into `AresConfig::agents`; must exist and must differ from `entry_agent`
    /// (otherwise `validate` fails with `MissingAgent` / `CircularReference`).
    pub fallback_agent: Option<String>,

    /// Depth cap; defaults to `3`. Semantics are defined by the consumer (not visible here).
    #[serde(default = "default_max_depth")]
    pub max_depth: u8,

    /// Iteration cap; defaults to `5`. Semantics are defined by the consumer (not visible here).
    #[serde(default = "default_max_iterations")]
    pub max_iterations: u8,

    /// Defaults to `false`. Presumably allows sub-agents to run concurrently — confirm.
    #[serde(default)]
    pub parallel_subagents: bool,
}
427
/// Serde default for `WorkflowConfig::max_depth`.
fn default_max_depth() -> u8 {
    3_u8
}

/// Serde default for `WorkflowConfig::max_iterations`.
fn default_max_iterations() -> u8 {
    5_u8
}
435
/// `[rag]` section. Every field is optional; defaults are noted per field. The meaning of
/// the string-valued strategy/model names is defined by the consumer (not visible here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Vector store name; defaults to `"ares-vector"`.
    #[serde(default = "default_vector_store")]
    pub vector_store: String,

    /// Vector storage path; defaults to `./data/vectors`.
    #[serde(default = "default_vector_path")]
    pub vector_path: String,

    /// Dense embedding model name; defaults to `"bge-small-en-v1.5"`.
    #[serde(default = "default_embedding_model")]
    pub embedding_model: String,

    /// Whether sparse embeddings are enabled; defaults to `false`.
    #[serde(default)]
    pub sparse_embeddings: bool,

    /// Sparse embedding model name; defaults to `"splade-pp-en-v1"`.
    #[serde(default = "default_sparse_model")]
    pub sparse_model: String,

    /// Chunking strategy name; defaults to `"word"`.
    #[serde(default = "default_chunking_strategy")]
    pub chunking_strategy: String,

    /// Chunk size; defaults to `200` (unit presumably follows the chunking strategy — confirm).
    #[serde(default = "default_chunk_size")]
    pub chunk_size: usize,

    /// Overlap between chunks; defaults to `50`.
    #[serde(default = "default_chunk_overlap")]
    pub chunk_overlap: usize,

    /// Minimum chunk size; defaults to `20`.
    #[serde(default = "default_min_chunk_size")]
    pub min_chunk_size: usize,

    /// Search strategy name; defaults to `"semantic"`.
    #[serde(default = "default_search_strategy")]
    pub search_strategy: String,

    /// Maximum number of search results; defaults to `10`.
    #[serde(default = "default_search_limit")]
    pub search_limit: usize,

    /// Score threshold; defaults to `0.0`.
    #[serde(default)]
    pub search_threshold: f32,

    /// Optional hybrid-search weights; `None` (not `HybridWeightsConfig::default()`) when omitted.
    #[serde(default)]
    pub hybrid_weights: Option<HybridWeightsConfig>,

    /// Whether reranking is enabled; defaults to `false`.
    #[serde(default)]
    pub rerank_enabled: bool,

    /// Reranker model name; defaults to `"bge-reranker-base"`.
    #[serde(default = "default_reranker_model")]
    pub reranker_model: String,

    /// Reranker weight; defaults to `0.6`.
    #[serde(default = "default_rerank_weight")]
    pub rerank_weight: f32,
}
513
/// `[rag.hybrid_weights]` sub-table. Presumably relative weights for combining scores in
/// hybrid search — confirm at the consumer. The defaults sum to 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridWeightsConfig {
    /// Semantic weight; defaults to `0.5`.
    #[serde(default = "default_semantic_weight")]
    pub semantic: f32,
    /// BM25 weight; defaults to `0.3`.
    #[serde(default = "default_bm25_weight")]
    pub bm25: f32,
    /// Fuzzy-match weight; defaults to `0.2`.
    #[serde(default = "default_fuzzy_weight")]
    pub fuzzy: f32,
}
527
528impl Default for HybridWeightsConfig {
529 fn default() -> Self {
530 Self {
531 semantic: 0.5,
532 bm25: 0.3,
533 fuzzy: 0.2,
534 }
535 }
536}
537
/// Serde default for `HybridWeightsConfig::semantic`.
fn default_semantic_weight() -> f32 {
    0.5_f32
}

/// Serde default for `HybridWeightsConfig::bm25`.
fn default_bm25_weight() -> f32 {
    0.3_f32
}

/// Serde default for `HybridWeightsConfig::fuzzy`.
fn default_fuzzy_weight() -> f32 {
    0.2_f32
}
549
/// Serde default for `rag.vector_store`.
fn default_vector_store() -> String {
    String::from("ares-vector")
}

/// Serde default for `rag.vector_path` (relative to the working directory).
fn default_vector_path() -> String {
    String::from("./data/vectors")
}

/// Serde default for `rag.embedding_model`.
fn default_embedding_model() -> String {
    String::from("bge-small-en-v1.5")
}

/// Serde default for `rag.sparse_model`.
fn default_sparse_model() -> String {
    String::from("splade-pp-en-v1")
}

/// Serde default for `rag.chunking_strategy`.
fn default_chunking_strategy() -> String {
    String::from("word")
}

/// Serde default for `rag.chunk_size`.
fn default_chunk_size() -> usize {
    200_usize
}

/// Serde default for `rag.chunk_overlap`.
fn default_chunk_overlap() -> usize {
    50_usize
}

/// Serde default for `rag.min_chunk_size`.
fn default_min_chunk_size() -> usize {
    20_usize
}

/// Serde default for `rag.search_strategy`.
fn default_search_strategy() -> String {
    String::from("semantic")
}

/// Serde default for `rag.search_limit`.
fn default_search_limit() -> usize {
    10_usize
}

/// Serde default for `rag.reranker_model`.
fn default_reranker_model() -> String {
    String::from("bge-reranker-base")
}

/// Serde default for `rag.rerank_weight`.
fn default_rerank_weight() -> f32 {
    0.6_f32
}
597
598impl Default for RagConfig {
599 fn default() -> Self {
600 Self {
601 vector_store: default_vector_store(),
602 vector_path: default_vector_path(),
603 embedding_model: default_embedding_model(),
604 sparse_embeddings: false,
605 sparse_model: default_sparse_model(),
606 chunking_strategy: default_chunking_strategy(),
607 chunk_size: default_chunk_size(),
608 chunk_overlap: default_chunk_overlap(),
609 min_chunk_size: default_min_chunk_size(),
610 search_strategy: default_search_strategy(),
611 search_limit: default_search_limit(),
612 search_threshold: 0.0,
613 hybrid_weights: None,
614 rerank_enabled: false,
615 reranker_model: default_reranker_model(),
616 rerank_weight: default_rerank_weight(),
617 }
618 }
619}
620
/// `[config]` section: directories for additional configuration, plus hot-reload settings.
/// Presumably each directory holds per-kind config files (agents, workflows, models, tools,
/// MCPs) — the consumer is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicConfigPaths {
    /// Defaults to `config/agents`.
    #[serde(default = "default_agents_dir")]
    pub agents_dir: std::path::PathBuf,

    /// Defaults to `config/workflows`.
    #[serde(default = "default_workflows_dir")]
    pub workflows_dir: std::path::PathBuf,

    /// Defaults to `config/models`.
    #[serde(default = "default_models_dir")]
    pub models_dir: std::path::PathBuf,

    /// Defaults to `config/tools`.
    #[serde(default = "default_tools_dir")]
    pub tools_dir: std::path::PathBuf,

    /// Defaults to `config/mcps`.
    #[serde(default = "default_mcps_dir")]
    pub mcps_dir: std::path::PathBuf,

    /// Whether hot reload is enabled; defaults to `true`.
    #[serde(default = "default_hot_reload")]
    pub hot_reload: bool,

    /// Defaults to `1000`; presumably a polling interval in milliseconds, per the name.
    #[serde(default = "default_watch_interval")]
    pub watch_interval_ms: u64,
}
658
/// Serde default for `config.agents_dir`.
fn default_agents_dir() -> PathBuf {
    PathBuf::from("config/agents")
}

/// Serde default for `config.workflows_dir`.
fn default_workflows_dir() -> PathBuf {
    PathBuf::from("config/workflows")
}

/// Serde default for `config.models_dir`.
fn default_models_dir() -> PathBuf {
    PathBuf::from("config/models")
}

/// Serde default for `config.tools_dir`.
fn default_tools_dir() -> PathBuf {
    PathBuf::from("config/tools")
}

/// Serde default for `config.mcps_dir`.
fn default_mcps_dir() -> PathBuf {
    PathBuf::from("config/mcps")
}

/// Serde default for `config.hot_reload`.
fn default_hot_reload() -> bool {
    !false
}

/// Serde default for `config.watch_interval_ms`.
fn default_watch_interval() -> u64 {
    1_000
}
686
687impl Default for DynamicConfigPaths {
688 fn default() -> Self {
689 Self {
690 agents_dir: default_agents_dir(),
691 workflows_dir: default_workflows_dir(),
692 models_dir: default_models_dir(),
693 tools_dir: default_tools_dir(),
694 mcps_dir: default_mcps_dir(),
695 hot_reload: default_hot_reload(),
696 watch_interval_ms: default_watch_interval(),
697 }
698 }
699}
700
/// A non-fatal finding reported by `AresConfig::validate_with_warnings`.
#[derive(Debug, Clone)]
pub struct ConfigWarning {
    /// Category of the finding.
    pub kind: ConfigWarningKind,

    /// Human-readable description; this is what `Display` prints.
    pub message: String,
}
712
/// Category of a non-fatal configuration warning.
///
/// Fieldless, so it is `Copy` and usable as a `HashSet`/`HashMap` key (e.g. to group or
/// de-duplicate warnings by kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigWarningKind {
    /// A provider is defined but no model references it.
    UnusedProvider,

    /// A model is defined but no agent references it.
    UnusedModel,

    /// A tool is defined but no agent lists it.
    UnusedTool,

    /// An agent is defined but no workflow references it as entry or fallback agent.
    UnusedAgent,
}
728
729impl std::fmt::Display for ConfigWarning {
730 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
731 write!(f, "{}", self.message)
732 }
733}
734
/// Errors produced while loading, validating or watching the configuration.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// The configuration file does not exist.
    #[error("Configuration file not found: {0}")]
    FileNotFound(PathBuf),

    /// An I/O error occurred (converted automatically from `std::io::Error`).
    #[error("Failed to read configuration file: {0}")]
    ReadError(#[from] std::io::Error),

    /// The file is not valid TOML or does not match the expected schema.
    #[error("Failed to parse TOML: {0}")]
    ParseError(#[from] toml::de::Error),

    /// Free-form validation failure (e.g. missing LlamaCpp model file, JWT secret too short).
    #[error("Validation error: {0}")]
    ValidationError(String),

    /// An environment variable named in the config is not set. Field: variable name.
    #[error("Environment variable '{0}' referenced in config is not set")]
    MissingEnvVar(String),

    /// Fields: (missing provider name, referencing model name).
    #[error("Provider '{0}' referenced by model '{1}' does not exist")]
    MissingProvider(String, String),

    /// Fields: (missing model name, referencing agent name).
    #[error("Model '{0}' referenced by agent '{1}' does not exist")]
    MissingModel(String, String),

    /// Fields: (missing agent name, referencing workflow name).
    #[error("Agent '{0}' referenced by workflow '{1}' does not exist")]
    MissingAgent(String, String),

    /// Fields: (missing tool name, referencing agent name).
    #[error("Tool '{0}' referenced by agent '{1}' does not exist")]
    MissingTool(String, String),

    /// A workflow's agent references form a loop (e.g. entry agent equals fallback agent).
    #[error("Circular reference detected: {0}")]
    CircularReference(String),

    /// The filesystem watcher used for hot reload failed.
    #[error("Watch error: {0}")]
    WatchError(#[from] notify::Error),
}
782
783impl AresConfig {
784 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
791 let path = path.as_ref();
792
793 if !path.exists() {
794 return Err(ConfigError::FileNotFound(path.to_path_buf()));
795 }
796
797 let content = fs::read_to_string(path)?;
798 let config: AresConfig = toml::from_str(&content)?;
799
800 config.validate()?;
802
803 Ok(config)
804 }
805
806 pub fn load_unchecked<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
812 let path = path.as_ref();
813
814 if !path.exists() {
815 return Err(ConfigError::FileNotFound(path.to_path_buf()));
816 }
817
818 let content = fs::read_to_string(path)?;
819 let config: AresConfig = toml::from_str(&content)?;
820
821 Ok(config)
822 }
823
824 pub fn validate(&self) -> Result<(), ConfigError> {
826 self.validate_env_var(&self.auth.jwt_secret_env)?;
828 self.validate_env_var(&self.auth.api_key_env)?;
829
830 if let Some(ref env) = self.database.turso_url_env {
832 self.validate_env_var(env)?;
833 }
834 if let Some(ref env) = self.database.turso_token_env {
835 self.validate_env_var(env)?;
836 }
837 if let Some(ref qdrant) = self.database.qdrant {
838 if let Some(ref env) = qdrant.api_key_env {
839 self.validate_env_var(env)?;
840 }
841 }
842
843 for (name, provider) in &self.providers {
845 match provider {
846 ProviderConfig::OpenAI { api_key_env, .. } => {
847 self.validate_env_var(api_key_env)?;
848 }
849 ProviderConfig::LlamaCpp { model_path, .. } => {
850 if !Path::new(model_path).exists() {
852 return Err(ConfigError::ValidationError(format!(
853 "LlamaCpp model path does not exist: {} (provider: {})",
854 model_path, name
855 )));
856 }
857 }
858 ProviderConfig::Ollama { .. } => {
859 }
861 }
862 }
863
864 for (model_name, model_config) in &self.models {
866 if !self.providers.contains_key(&model_config.provider) {
867 return Err(ConfigError::MissingProvider(
868 model_config.provider.clone(),
869 model_name.clone(),
870 ));
871 }
872 }
873
874 for (agent_name, agent_config) in &self.agents {
876 if !self.models.contains_key(&agent_config.model) {
877 return Err(ConfigError::MissingModel(
878 agent_config.model.clone(),
879 agent_name.clone(),
880 ));
881 }
882
883 for tool_name in &agent_config.tools {
884 if !self.tools.contains_key(tool_name) {
885 return Err(ConfigError::MissingTool(
886 tool_name.clone(),
887 agent_name.clone(),
888 ));
889 }
890 }
891 }
892
893 for (workflow_name, workflow_config) in &self.workflows {
895 if !self.agents.contains_key(&workflow_config.entry_agent) {
896 return Err(ConfigError::MissingAgent(
897 workflow_config.entry_agent.clone(),
898 workflow_name.clone(),
899 ));
900 }
901
902 if let Some(ref fallback) = workflow_config.fallback_agent {
903 if !self.agents.contains_key(fallback) {
904 return Err(ConfigError::MissingAgent(
905 fallback.clone(),
906 workflow_name.clone(),
907 ));
908 }
909 }
910 }
911
912 self.detect_circular_references()?;
914
915 Ok(())
916 }
917
918 fn detect_circular_references(&self) -> Result<(), ConfigError> {
923 use std::collections::HashSet;
924
925 for (workflow_name, workflow_config) in &self.workflows {
926 let mut visited = HashSet::new();
927 let mut current = Some(workflow_config.entry_agent.as_str());
928
929 while let Some(agent_name) = current {
930 if visited.contains(agent_name) {
931 return Err(ConfigError::CircularReference(format!(
932 "Circular reference detected in workflow '{}': agent '{}' appears multiple times in the chain",
933 workflow_name, agent_name
934 )));
935 }
936 visited.insert(agent_name);
937
938 current = None;
941
942 if let Some(ref fallback) = workflow_config.fallback_agent {
944 if fallback == &workflow_config.entry_agent {
945 return Err(ConfigError::CircularReference(format!(
946 "Workflow '{}' has entry_agent '{}' that equals fallback_agent",
947 workflow_name, workflow_config.entry_agent
948 )));
949 }
950 }
951 }
952 }
953
954 Ok(())
955 }
956
957 pub fn validate_with_warnings(&self) -> Result<Vec<ConfigWarning>, ConfigError> {
961 self.validate()?;
963
964 let mut warnings = Vec::new();
966
967 warnings.extend(self.check_unused_providers());
969
970 warnings.extend(self.check_unused_models());
972
973 warnings.extend(self.check_unused_tools());
975
976 warnings.extend(self.check_unused_agents());
978
979 Ok(warnings)
980 }
981
982 fn check_unused_providers(&self) -> Vec<ConfigWarning> {
984 use std::collections::HashSet;
985
986 let referenced: HashSet<_> = self.models.values().map(|m| m.provider.as_str()).collect();
987
988 self.providers
989 .keys()
990 .filter(|name| !referenced.contains(name.as_str()))
991 .map(|name| ConfigWarning {
992 kind: ConfigWarningKind::UnusedProvider,
993 message: format!(
994 "Provider '{}' is defined but not referenced by any model",
995 name
996 ),
997 })
998 .collect()
999 }
1000
1001 fn check_unused_models(&self) -> Vec<ConfigWarning> {
1003 use std::collections::HashSet;
1004
1005 let referenced: HashSet<_> = self.agents.values().map(|a| a.model.as_str()).collect();
1006
1007 self.models
1008 .keys()
1009 .filter(|name| !referenced.contains(name.as_str()))
1010 .map(|name| ConfigWarning {
1011 kind: ConfigWarningKind::UnusedModel,
1012 message: format!(
1013 "Model '{}' is defined but not referenced by any agent",
1014 name
1015 ),
1016 })
1017 .collect()
1018 }
1019
1020 fn check_unused_tools(&self) -> Vec<ConfigWarning> {
1022 use std::collections::HashSet;
1023
1024 let referenced: HashSet<_> = self
1025 .agents
1026 .values()
1027 .flat_map(|a| a.tools.iter().map(|t| t.as_str()))
1028 .collect();
1029
1030 self.tools
1031 .keys()
1032 .filter(|name| !referenced.contains(name.as_str()))
1033 .map(|name| ConfigWarning {
1034 kind: ConfigWarningKind::UnusedTool,
1035 message: format!("Tool '{}' is defined but not referenced by any agent", name),
1036 })
1037 .collect()
1038 }
1039
1040 fn check_unused_agents(&self) -> Vec<ConfigWarning> {
1042 use std::collections::HashSet;
1043
1044 let referenced: HashSet<_> = self
1045 .workflows
1046 .values()
1047 .flat_map(|w| {
1048 let mut refs = vec![w.entry_agent.as_str()];
1049 if let Some(ref fallback) = w.fallback_agent {
1050 refs.push(fallback.as_str());
1051 }
1052 refs
1053 })
1054 .collect();
1055
1056 let system_agents: HashSet<&str> = ["orchestrator", "router"].into_iter().collect();
1058
1059 self.agents
1060 .keys()
1061 .filter(|name| {
1062 !referenced.contains(name.as_str()) && !system_agents.contains(name.as_str())
1063 })
1064 .map(|name| ConfigWarning {
1065 kind: ConfigWarningKind::UnusedAgent,
1066 message: format!(
1067 "Agent '{}' is defined but not referenced by any workflow",
1068 name
1069 ),
1070 })
1071 .collect()
1072 }
1073
1074 fn validate_env_var(&self, name: &str) -> Result<(), ConfigError> {
1075 std::env::var(name).map_err(|_| ConfigError::MissingEnvVar(name.to_string()))?;
1076 Ok(())
1077 }
1078
1079 pub fn resolve_env(&self, env_name: &str) -> Option<String> {
1081 std::env::var(env_name).ok()
1082 }
1083
1084 const JWT_SECRET_MIN_LENGTH: usize = 32;
1086
1087 pub fn jwt_secret(&self) -> Result<String, ConfigError> {
1094 let secret = self
1095 .resolve_env(&self.auth.jwt_secret_env)
1096 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.jwt_secret_env.clone()))?;
1097
1098 if secret.len() < Self::JWT_SECRET_MIN_LENGTH {
1099 return Err(ConfigError::ValidationError(format!(
1100 "JWT_SECRET must be at least {} characters for security (current: {} chars). \
1101 Use a cryptographically random string, e.g.: openssl rand -base64 32",
1102 Self::JWT_SECRET_MIN_LENGTH,
1103 secret.len()
1104 )));
1105 }
1106
1107 Ok(secret)
1108 }
1109
1110 pub fn api_key(&self) -> Result<String, ConfigError> {
1112 self.resolve_env(&self.auth.api_key_env)
1113 .ok_or_else(|| ConfigError::MissingEnvVar(self.auth.api_key_env.clone()))
1114 }
1115
1116 pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
1118 self.providers.get(name)
1119 }
1120
1121 pub fn get_model(&self, name: &str) -> Option<&ModelConfig> {
1123 self.models.get(name)
1124 }
1125
1126 pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
1128 self.agents.get(name)
1129 }
1130
1131 pub fn get_tool(&self, name: &str) -> Option<&ToolConfig> {
1133 self.tools.get(name)
1134 }
1135
1136 pub fn get_workflow(&self, name: &str) -> Option<&WorkflowConfig> {
1138 self.workflows.get(name)
1139 }
1140
1141 pub fn enabled_tools(&self) -> Vec<&str> {
1143 self.tools
1144 .iter()
1145 .filter(|(_, config)| config.enabled)
1146 .map(|(name, _)| name.as_str())
1147 .collect()
1148 }
1149
1150 pub fn agent_tools(&self, agent_name: &str) -> Vec<&str> {
1152 self.get_agent(agent_name)
1153 .map(|agent| {
1154 agent
1155 .tools
1156 .iter()
1157 .filter(|t| self.get_tool(t).map(|tc| tc.enabled).unwrap_or(false))
1158 .map(|s| s.as_str())
1159 .collect()
1160 })
1161 .unwrap_or_default()
1162 }
1163}
1164
/// Owns the live configuration and, optionally, a file watcher that hot-reloads it.
///
/// Readers take snapshots via `config()`. Reloads replace the whole `AresConfig` through
/// `ArcSwap::store`, so a reader sees either the old or the new configuration, never a mix.
pub struct AresConfigManager {
    /// Swappable configuration, shared with clones of this manager and the reload task.
    config: Arc<ArcSwap<AresConfig>>,
    /// File the configuration is (re)loaded from.
    config_path: PathBuf,
    /// Active filesystem watcher; `None` when not watching (cleared by `stop_watching`).
    watcher: RwLock<Option<RecommendedWatcher>>,
    /// Sender half of the reload channel created by `start_watching`.
    reload_tx: Option<mpsc::UnboundedSender<()>>,
}
1174
1175impl AresConfigManager {
1176 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
1182 let path = path.as_ref();
1184 let path = if path.is_absolute() {
1185 path.to_path_buf()
1186 } else {
1187 std::env::current_dir()
1188 .map_err(ConfigError::ReadError)?
1189 .join(path)
1190 };
1191
1192 let config = AresConfig::load(&path)?;
1193
1194 Ok(Self {
1195 config: Arc::new(ArcSwap::from_pointee(config)),
1196 config_path: path,
1197 watcher: RwLock::new(None),
1198 reload_tx: None,
1199 })
1200 }
1201
1202 pub fn config(&self) -> Arc<AresConfig> {
1204 self.config.load_full()
1205 }
1206
1207 pub fn reload(&self) -> Result<(), ConfigError> {
1209 info!("Reloading configuration from {:?}", self.config_path);
1210
1211 let new_config = AresConfig::load(&self.config_path)?;
1212 self.config.store(Arc::new(new_config));
1213
1214 info!("Configuration reloaded successfully");
1215 Ok(())
1216 }
1217
1218 pub fn start_watching(&mut self) -> Result<(), ConfigError> {
1220 let (tx, mut rx) = mpsc::unbounded_channel::<()>();
1221 self.reload_tx = Some(tx.clone());
1222
1223 let config_path = self.config_path.clone();
1224 let config_arc = Arc::clone(&self.config);
1225
1226 let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
1228 match res {
1229 Ok(event) => {
1230 if event.kind.is_modify() || event.kind.is_create() {
1231 let _ = tx.send(());
1233 }
1234 }
1235 Err(e) => {
1236 error!("Config watcher error: {:?}", e);
1237 }
1238 }
1239 })?;
1240
1241 if let Some(parent) = self.config_path.parent() {
1243 watcher.watch(parent, RecursiveMode::NonRecursive)?;
1244 }
1245
1246 *self.watcher.write() = Some(watcher);
1247
1248 let config_path_clone = config_path.clone();
1250 tokio::spawn(async move {
1251 let mut last_reload = std::time::Instant::now();
1252 let debounce_duration = Duration::from_millis(500);
1253
1254 while rx.recv().await.is_some() {
1255 if last_reload.elapsed() < debounce_duration {
1257 continue;
1258 }
1259
1260 tokio::time::sleep(Duration::from_millis(100)).await;
1262
1263 match AresConfig::load(&config_path_clone) {
1264 Ok(new_config) => {
1265 config_arc.store(Arc::new(new_config));
1266 info!("Configuration hot-reloaded successfully");
1267 last_reload = std::time::Instant::now();
1268 }
1269 Err(e) => {
1270 warn!(
1271 "Failed to hot-reload config: {}. Keeping previous config.",
1272 e
1273 );
1274 }
1275 }
1276 }
1277 });
1278
1279 info!("Configuration hot-reload watcher started");
1280 Ok(())
1281 }
1282
1283 pub fn stop_watching(&self) {
1285 *self.watcher.write() = None;
1286 info!("Configuration hot-reload watcher stopped");
1287 }
1288}
1289
1290impl Clone for AresConfigManager {
1291 fn clone(&self) -> Self {
1292 Self {
1293 config: Arc::clone(&self.config),
1294 config_path: self.config_path.clone(),
1295 watcher: RwLock::new(None), reload_tx: self.reload_tx.clone(),
1297 }
1298 }
1299}
1300
1301impl AresConfigManager {
1302 pub fn from_config(config: AresConfig) -> Self {
1305 Self {
1306 config: Arc::new(ArcSwap::from_pointee(config)),
1307 config_path: PathBuf::from("test-config.toml"),
1308 watcher: RwLock::new(None),
1309 reload_tx: None,
1310 }
1311 }
1312}
1313
1314#[cfg(test)]
1315mod tests {
1316 use super::*;
1317
1318 fn create_test_config() -> String {
1319 r#"
1320[server]
1321host = "127.0.0.1"
1322port = 3000
1323log_level = "debug"
1324
1325[auth]
1326jwt_secret_env = "TEST_JWT_SECRET"
1327jwt_access_expiry = 900
1328jwt_refresh_expiry = 604800
1329api_key_env = "TEST_API_KEY"
1330
1331[database]
1332url = "./data/test.db"
1333
1334[providers.ollama-local]
1335type = "ollama"
1336base_url = "http://localhost:11434"
1337default_model = "ministral-3:3b"
1338
1339[models.default]
1340provider = "ollama-local"
1341model = "ministral-3:3b"
1342temperature = 0.7
1343max_tokens = 512
1344
1345[tools.calculator]
1346enabled = true
1347description = "Basic calculator"
1348timeout_secs = 10
1349
1350[agents.router]
1351model = "default"
1352tools = []
1353max_tool_iterations = 5
1354
1355[workflows.default]
1356entry_agent = "router"
1357max_depth = 3
1358max_iterations = 5
1359"#
1360 .to_string()
1361 }
1362
1363 #[test]
1364 fn test_parse_config() {
1365 unsafe {
1368 std::env::set_var(
1369 "TEST_JWT_SECRET",
1370 "test-secret-at-least-32-characters-long-at-least-32-characters-long",
1371 );
1372 std::env::set_var("TEST_API_KEY", "test-api-key");
1373 }
1374
1375 let content = create_test_config();
1376 let config: AresConfig = toml::from_str(&content).expect("Failed to parse config");
1377
1378 assert_eq!(config.server.host, "127.0.0.1");
1379 assert_eq!(config.server.port, 3000);
1380 assert!(config.providers.contains_key("ollama-local"));
1381 assert!(config.models.contains_key("default"));
1382 assert!(config.agents.contains_key("router"));
1383 }
1384
1385 #[test]
1386 fn test_validation_missing_provider() {
1387 unsafe {
1389 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1390 std::env::set_var("TEST_API_KEY", "test-key");
1391 }
1392
1393 let content = r#"
1394[server]
1395[auth]
1396jwt_secret_env = "TEST_JWT_SECRET"
1397api_key_env = "TEST_API_KEY"
1398[database]
1399[models.test]
1400provider = "nonexistent"
1401model = "test"
1402"#;
1403
1404 let config: AresConfig = toml::from_str(content).unwrap();
1405 let result = config.validate();
1406
1407 assert!(matches!(result, Err(ConfigError::MissingProvider(_, _))));
1408 }
1409
1410 #[test]
1411 fn test_validation_missing_model() {
1412 unsafe {
1414 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1415 std::env::set_var("TEST_API_KEY", "test-key");
1416 }
1417
1418 let content = r#"
1419[server]
1420[auth]
1421jwt_secret_env = "TEST_JWT_SECRET"
1422api_key_env = "TEST_API_KEY"
1423[database]
1424[providers.test]
1425type = "ollama"
1426default_model = "ministral-3:3b"
1427[agents.test]
1428model = "nonexistent"
1429"#;
1430
1431 let config: AresConfig = toml::from_str(content).unwrap();
1432 let result = config.validate();
1433
1434 assert!(matches!(result, Err(ConfigError::MissingModel(_, _))));
1435 }
1436
1437 #[test]
1438 fn test_validation_missing_tool() {
1439 unsafe {
1441 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1442 std::env::set_var("TEST_API_KEY", "test-key");
1443 }
1444
1445 let content = r#"
1446[server]
1447[auth]
1448jwt_secret_env = "TEST_JWT_SECRET"
1449api_key_env = "TEST_API_KEY"
1450[database]
1451[providers.test]
1452type = "ollama"
1453default_model = "ministral-3:3b"
1454[models.default]
1455provider = "test"
1456model = "ministral-3:3b"
1457[agents.test]
1458model = "default"
1459tools = ["nonexistent_tool"]
1460"#;
1461
1462 let config: AresConfig = toml::from_str(content).unwrap();
1463 let result = config.validate();
1464
1465 assert!(matches!(result, Err(ConfigError::MissingTool(_, _))));
1466 }
1467
1468 #[test]
1469 fn test_validation_missing_workflow_agent() {
1470 unsafe {
1472 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1473 std::env::set_var("TEST_API_KEY", "test-key");
1474 }
1475
1476 let content = r#"
1477[server]
1478[auth]
1479jwt_secret_env = "TEST_JWT_SECRET"
1480api_key_env = "TEST_API_KEY"
1481[database]
1482[workflows.test]
1483entry_agent = "nonexistent_agent"
1484"#;
1485
1486 let config: AresConfig = toml::from_str(content).unwrap();
1487 let result = config.validate();
1488
1489 assert!(matches!(result, Err(ConfigError::MissingAgent(_, _))));
1490 }
1491
1492 #[test]
1493 fn test_get_provider() {
1494 let content = create_test_config();
1495 let config: AresConfig = toml::from_str(&content).unwrap();
1496
1497 assert!(config.get_provider("ollama-local").is_some());
1498 assert!(config.get_provider("nonexistent").is_none());
1499 }
1500
1501 #[test]
1502 fn test_get_model() {
1503 let content = create_test_config();
1504 let config: AresConfig = toml::from_str(&content).unwrap();
1505
1506 assert!(config.get_model("default").is_some());
1507 assert!(config.get_model("nonexistent").is_none());
1508 }
1509
1510 #[test]
1511 fn test_get_agent() {
1512 let content = create_test_config();
1513 let config: AresConfig = toml::from_str(&content).unwrap();
1514
1515 assert!(config.get_agent("router").is_some());
1516 assert!(config.get_agent("nonexistent").is_none());
1517 }
1518
1519 #[test]
1520 fn test_get_tool() {
1521 let content = create_test_config();
1522 let config: AresConfig = toml::from_str(&content).unwrap();
1523
1524 assert!(config.get_tool("calculator").is_some());
1525 assert!(config.get_tool("nonexistent").is_none());
1526 }
1527
1528 #[test]
1529 fn test_enabled_tools() {
1530 let content = r#"
1531[server]
1532[auth]
1533jwt_secret_env = "TEST_JWT_SECRET"
1534api_key_env = "TEST_API_KEY"
1535[database]
1536[tools.enabled_tool]
1537enabled = true
1538[tools.disabled_tool]
1539enabled = false
1540"#;
1541
1542 let config: AresConfig = toml::from_str(content).unwrap();
1543 let enabled = config.enabled_tools();
1544
1545 assert!(enabled.contains(&"enabled_tool"));
1546 assert!(!enabled.contains(&"disabled_tool"));
1547 }
1548
1549 #[test]
1550 fn test_defaults() {
1551 let content = r#"
1552[server]
1553[auth]
1554jwt_secret_env = "TEST_JWT_SECRET"
1555api_key_env = "TEST_API_KEY"
1556[database]
1557"#;
1558
1559 let config: AresConfig = toml::from_str(content).unwrap();
1560
1561 assert_eq!(config.server.host, "127.0.0.1");
1563 assert_eq!(config.server.port, 3000);
1564 assert_eq!(config.server.log_level, "info");
1565
1566 assert_eq!(config.auth.jwt_access_expiry, 900);
1568 assert_eq!(config.auth.jwt_refresh_expiry, 604800);
1569
1570 assert_eq!(config.database.url, "./data/ares.db");
1572
1573 assert_eq!(config.rag.embedding_model, "bge-small-en-v1.5");
1575 assert_eq!(config.rag.chunk_size, 200);
1576 assert_eq!(config.rag.chunk_overlap, 50);
1577 assert_eq!(config.rag.vector_store, "ares-vector");
1578 assert_eq!(config.rag.search_strategy, "semantic");
1579 }
1580
1581 #[test]
1582 fn test_config_manager_from_config() {
1583 let content = create_test_config();
1584 let config: AresConfig = toml::from_str(&content).unwrap();
1585
1586 let manager = AresConfigManager::from_config(config.clone());
1587 let loaded = manager.config();
1588
1589 assert_eq!(loaded.server.host, config.server.host);
1590 assert_eq!(loaded.server.port, config.server.port);
1591 }
1592
1593 #[test]
1594 fn test_circular_reference_detection() {
1595 unsafe {
1597 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1598 std::env::set_var("TEST_API_KEY", "test-key");
1599 }
1600
1601 let content = r#"
1602[server]
1603[auth]
1604jwt_secret_env = "TEST_JWT_SECRET"
1605api_key_env = "TEST_API_KEY"
1606[database]
1607[providers.test]
1608type = "ollama"
1609default_model = "ministral-3:3b"
1610[models.default]
1611provider = "test"
1612model = "ministral-3:3b"
1613[agents.agent_a]
1614model = "default"
1615[workflows.circular]
1616entry_agent = "agent_a"
1617fallback_agent = "agent_a"
1618"#;
1619
1620 let config: AresConfig = toml::from_str(content).unwrap();
1621 let result = config.validate();
1622
1623 assert!(matches!(result, Err(ConfigError::CircularReference(_))));
1624 }
1625
1626 #[test]
1627 fn test_unused_provider_warning() {
1628 unsafe {
1630 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1631 std::env::set_var("TEST_API_KEY", "test-key");
1632 }
1633
1634 let content = r#"
1635[server]
1636[auth]
1637jwt_secret_env = "TEST_JWT_SECRET"
1638api_key_env = "TEST_API_KEY"
1639[database]
1640[providers.used]
1641type = "ollama"
1642default_model = "ministral-3:3b"
1643[providers.unused]
1644type = "ollama"
1645default_model = "ministral-3:3b"
1646[models.default]
1647provider = "used"
1648model = "ministral-3:3b"
1649[agents.router]
1650model = "default"
1651"#;
1652
1653 let config: AresConfig = toml::from_str(content).unwrap();
1654 let warnings = config.validate_with_warnings().unwrap();
1655
1656 assert!(warnings
1657 .iter()
1658 .any(|w| w.kind == ConfigWarningKind::UnusedProvider && w.message.contains("unused")));
1659 }
1660
1661 #[test]
1662 fn test_unused_model_warning() {
1663 unsafe {
1665 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1666 std::env::set_var("TEST_API_KEY", "test-key");
1667 }
1668
1669 let content = r#"
1670[server]
1671[auth]
1672jwt_secret_env = "TEST_JWT_SECRET"
1673api_key_env = "TEST_API_KEY"
1674[database]
1675[providers.test]
1676type = "ollama"
1677default_model = "ministral-3:3b"
1678[models.used]
1679provider = "test"
1680model = "ministral-3:3b"
1681[models.unused]
1682provider = "test"
1683model = "other"
1684[agents.router]
1685model = "used"
1686"#;
1687
1688 let config: AresConfig = toml::from_str(content).unwrap();
1689 let warnings = config.validate_with_warnings().unwrap();
1690
1691 assert!(warnings
1692 .iter()
1693 .any(|w| w.kind == ConfigWarningKind::UnusedModel && w.message.contains("unused")));
1694 }
1695
1696 #[test]
1697 fn test_unused_tool_warning() {
1698 unsafe {
1700 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1701 std::env::set_var("TEST_API_KEY", "test-key");
1702 }
1703
1704 let content = r#"
1705[server]
1706[auth]
1707jwt_secret_env = "TEST_JWT_SECRET"
1708api_key_env = "TEST_API_KEY"
1709[database]
1710[providers.test]
1711type = "ollama"
1712default_model = "ministral-3:3b"
1713[models.default]
1714provider = "test"
1715model = "ministral-3:3b"
1716[tools.used_tool]
1717enabled = true
1718[tools.unused_tool]
1719enabled = true
1720[agents.router]
1721model = "default"
1722tools = ["used_tool"]
1723"#;
1724
1725 let config: AresConfig = toml::from_str(content).unwrap();
1726 let warnings = config.validate_with_warnings().unwrap();
1727
1728 assert!(warnings
1729 .iter()
1730 .any(|w| w.kind == ConfigWarningKind::UnusedTool && w.message.contains("unused_tool")));
1731 }
1732
1733 #[test]
1734 fn test_unused_agent_warning() {
1735 unsafe {
1737 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1738 std::env::set_var("TEST_API_KEY", "test-key");
1739 }
1740
1741 let content = r#"
1742[server]
1743[auth]
1744jwt_secret_env = "TEST_JWT_SECRET"
1745api_key_env = "TEST_API_KEY"
1746[database]
1747[providers.test]
1748type = "ollama"
1749default_model = "ministral-3:3b"
1750[models.default]
1751provider = "test"
1752model = "ministral-3:3b"
1753[agents.router]
1754model = "default"
1755[agents.orphaned]
1756model = "default"
1757[workflows.test_flow]
1758entry_agent = "router"
1759"#;
1760
1761 let config: AresConfig = toml::from_str(content).unwrap();
1762 let warnings = config.validate_with_warnings().unwrap();
1763
1764 assert!(warnings
1765 .iter()
1766 .any(|w| w.kind == ConfigWarningKind::UnusedAgent && w.message.contains("orphaned")));
1767 }
1768
1769 #[test]
1770 fn test_no_warnings_for_fully_connected_config() {
1771 unsafe {
1773 std::env::set_var("TEST_JWT_SECRET", "test-secret-at-least-32-characters-long");
1774 std::env::set_var("TEST_API_KEY", "test-key");
1775 }
1776
1777 let content = r#"
1778[server]
1779[auth]
1780jwt_secret_env = "TEST_JWT_SECRET"
1781api_key_env = "TEST_API_KEY"
1782[database]
1783[providers.test]
1784type = "ollama"
1785default_model = "ministral-3:3b"
1786[models.default]
1787provider = "test"
1788model = "ministral-3:3b"
1789[tools.calc]
1790enabled = true
1791[agents.router]
1792model = "default"
1793tools = ["calc"]
1794[workflows.main]
1795entry_agent = "router"
1796"#;
1797
1798 let config: AresConfig = toml::from_str(content).unwrap();
1799 let warnings = config.validate_with_warnings().unwrap();
1800
1801 assert!(
1802 warnings.is_empty(),
1803 "Expected no warnings but got: {:?}",
1804 warnings
1805 );
1806 }
1807}