/// Built-in default configuration, embedded into the binary at compile time
/// from `../default.yaml` (path relative to this source file).
///
/// It is the first (lowest-priority) layer merged by
/// [`BrainConfig::load_from`] and the template written out by
/// [`BrainConfig::write_default_config`].
const DEFAULT_CONFIG: &str = include_str!("../default.yaml");
11
12use std::{
13 collections::HashMap,
14 path::{Path, PathBuf},
15};
16
17use figment::{
18 providers::{Env, Format, Yaml},
19 Figment,
20};
21use serde::{Deserialize, Serialize};
22
/// Root of the application configuration.
///
/// Each field maps to one top-level section of the configuration document.
/// None of the sections is marked `#[serde(default)]`, so every one of them
/// has to be present once all layers are merged (see
/// [`BrainConfig::load_from`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrainConfig {
    /// General settings (`brain` section): version and data directory.
    pub brain: GeneralConfig,
    /// Storage paths and HNSW parameters.
    pub storage: StorageConfig,
    /// LLM provider settings.
    pub llm: LlmConfig,
    /// Embedding model settings.
    pub embedding: EmbeddingConfig,
    /// Memory subsystem settings (episodic, semantic, search, consolidation).
    pub memory: MemoryConfig,
    /// Encryption toggle.
    pub encryption: EncryptionConfig,
    /// Command-execution allowlist and timeout.
    pub security: SecurityConfig,
    /// Action backends (web search, scheduling, messaging) and resilience.
    pub actions: ActionsConfig,
    /// Proactivity settings.
    pub proactivity: ProactivityConfig,
    /// Network/protocol adapter settings.
    pub adapters: AdaptersConfig,
    /// API keys and their permissions.
    pub access: AccessConfig,
}
38
/// General settings (`brain` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
    /// Version string; `BrainConfig::default()` fills it from
    /// `CARGO_PKG_VERSION`.
    pub version: String,
    /// Root data directory. A leading `~/` is expanded by
    /// [`BrainConfig::data_dir`].
    pub data_dir: String,
}
44
/// Storage locations and vector-index parameters (`storage` section).
///
/// NOTE(review): [`BrainConfig::sqlite_path`] and
/// [`BrainConfig::ruvector_path`] derive their results from `brain.data_dir`
/// and do not read the two path fields below — confirm which value consumers
/// are expected to use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Configured path of the ruvector store.
    pub ruvector_path: String,
    /// Configured path of the SQLite database file.
    pub sqlite_path: String,
    /// HNSW index parameters.
    pub hnsw: HnswConfig,
}
51
/// HNSW index parameters (`storage.hnsw` section).
///
/// The fields are only stored here; how they are interpreted is decided by
/// the vector-store code outside this file. `BrainConfig::default()` uses
/// 200 / 16 / 50.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// `ef_construction` parameter.
    pub ef_construction: u32,
    /// `m` parameter.
    pub m: u32,
    /// `ef_search` parameter.
    pub ef_search: u32,
}
58
/// LLM provider settings (`llm` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Provider identifier (e.g. `"ollama"` in `BrainConfig::default()`).
    pub provider: String,
    /// Model name.
    pub model: String,
    /// Base URL of the provider. [`BrainConfig::validate`] rejects values
    /// that do not start with `http://` or `https://`.
    pub base_url: String,
    /// Sampling temperature. [`BrainConfig::validate`] warns above 1.5.
    pub temperature: f64,
    /// Maximum number of tokens.
    pub max_tokens: u32,
    /// Provider API key; empty when omitted.
    #[serde(default)]
    pub api_key: String,
    /// Intent LLM fallback toggle; `false` when omitted.
    #[serde(default)]
    pub intent_llm_fallback: bool,
}
75
/// Embedding model settings (`embedding` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Embedding model name.
    pub model: String,
    /// Embedding vector dimensionality (768 in `BrainConfig::default()`).
    pub dimensions: u32,
}
86
/// Memory subsystem settings (`memory` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Episodic memory limits.
    pub episodic: EpisodicConfig,
    /// Semantic memory thresholds.
    pub semantic: SemanticConfig,
    /// Search/ranking parameters.
    pub search: SearchConfig,
    /// Consolidation settings.
    pub consolidation: ConsolidationConfig,
}
94
/// Episodic memory limits (`memory.episodic` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodicConfig {
    /// Entry-count limit (100_000 in `BrainConfig::default()`).
    pub max_entries: u64,
    /// Retention period in days (365 in `BrainConfig::default()`).
    pub retention_days: u32,
}
100
/// Semantic memory thresholds (`memory.semantic` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
    /// Similarity threshold (0.65 in `BrainConfig::default()`).
    pub similarity_threshold: f64,
    /// Result-count limit (20 in `BrainConfig::default()`).
    pub max_results: u32,
}
106
/// Search and ranking parameters (`memory.search` section).
///
/// `hybrid_weight` and `rrf_k` are required; the remaining fields fall back
/// to the `default_*` functions defined below this struct when the key is
/// absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    /// Hybrid-search weight (0.7 in `BrainConfig::default()`).
    pub hybrid_weight: f64,
    /// `k` constant, presumably for reciprocal rank fusion — confirm against
    /// the search implementation (60 in `BrainConfig::default()`).
    pub rrf_k: u32,
    /// Pre-fusion candidate limit; defaults to 50.
    #[serde(default = "default_pre_fusion_limit")]
    pub pre_fusion_limit: u32,
    /// Importance weight; defaults to 0.3.
    #[serde(default = "default_importance_weight")]
    pub importance_weight: f64,
    /// Recency weight; defaults to 0.2.
    #[serde(default = "default_recency_weight")]
    pub recency_weight: f64,
    /// Decay rate; defaults to 0.01.
    #[serde(default = "default_decay_rate")]
    pub decay_rate: f64,
}
124
/// Serde fallback for `SearchConfig::pre_fusion_limit` when the key is absent.
fn default_pre_fusion_limit() -> u32 {
    const PRE_FUSION_LIMIT: u32 = 50;
    PRE_FUSION_LIMIT
}
/// Serde fallback for `SearchConfig::importance_weight` when the key is absent.
fn default_importance_weight() -> f64 {
    const IMPORTANCE_WEIGHT: f64 = 0.3;
    IMPORTANCE_WEIGHT
}
/// Serde fallback for `SearchConfig::recency_weight` when the key is absent.
fn default_recency_weight() -> f64 {
    const RECENCY_WEIGHT: f64 = 0.2;
    RECENCY_WEIGHT
}
/// Serde fallback for `SearchConfig::decay_rate` when the key is absent.
fn default_decay_rate() -> f64 {
    const DECAY_RATE: f64 = 0.01;
    DECAY_RATE
}
137
/// Memory consolidation settings (`memory.consolidation` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidationConfig {
    /// Whether consolidation is enabled.
    pub enabled: bool,
    /// Interval between runs in hours. [`BrainConfig::validate`] warns when
    /// this is 0 while consolidation is enabled.
    pub interval_hours: u32,
    /// Forgetting threshold (0.05 in `BrainConfig::default()`).
    pub forgetting_threshold: f64,
}
144
/// Encryption settings (`encryption` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
    /// Whether encryption is enabled (`false` in `BrainConfig::default()`).
    pub enabled: bool,
}
149
/// Command-execution restrictions (`security` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Allowlisted command names (e.g. `ls`, `git`, `cargo` in
    /// `BrainConfig::default()`).
    pub exec_allowlist: Vec<String>,
    /// Execution timeout in seconds.
    pub exec_timeout_seconds: u32,
}
155
/// Action backends (`actions` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionsConfig {
    /// Web-search action settings.
    pub web_search: WebSearchActionConfig,
    /// Scheduling action settings.
    pub scheduling: SchedulingActionConfig,
    /// Messaging action settings.
    pub messaging: MessagingActionConfig,
    /// Retry / circuit-breaker settings; falls back to
    /// `ResilienceConfig::default()` when the section is absent.
    #[serde(default)]
    pub resilience: ResilienceConfig,
}
164
/// Retry and circuit-breaker settings (`actions.resilience` section).
///
/// There are no per-field serde defaults: when the section is present, all
/// four keys must be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResilienceConfig {
    /// Maximum number of retries. [`BrainConfig::validate`] warns above 10.
    pub max_retries: u32,
    /// Base retry delay in milliseconds.
    pub retry_base_ms: u64,
    /// Circuit-breaker threshold. [`BrainConfig::validate`] warns when 0.
    pub circuit_breaker_threshold: u32,
    /// Circuit-breaker cooldown in seconds.
    pub circuit_breaker_cooldown_secs: u64,
}
172
173impl Default for ResilienceConfig {
174 fn default() -> Self {
175 Self {
176 max_retries: 2,
177 retry_base_ms: 500,
178 circuit_breaker_threshold: 5,
179 circuit_breaker_cooldown_secs: 60,
180 }
181 }
182}
183
/// Web-search backend selector.
///
/// Serialized in `snake_case` (`searxng`, `tavily`, `custom`). `Custom` is
/// the `Default` variant and is what `WebSearchActionConfig::provider` takes
/// when the key is omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WebSearchProvider {
    /// SearXNG backend.
    Searxng,
    /// Tavily backend; [`BrainConfig::validate`] warns when no `api_key` is set.
    Tavily,
    /// Custom endpoint; [`BrainConfig::validate`] warns when `endpoint` is blank.
    #[default]
    Custom,
}
192
/// Web-search action settings (`actions.web_search` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSearchActionConfig {
    /// Whether the web-search action is enabled.
    pub enabled: bool,
    /// Backend to use; `WebSearchProvider::Custom` when omitted.
    #[serde(default)]
    pub provider: WebSearchProvider,
    /// Backend endpoint URL.
    pub endpoint: String,
    /// Backend API key; empty when omitted.
    #[serde(default)]
    pub api_key: String,
    /// Request timeout in milliseconds. [`BrainConfig::validate`] warns when
    /// it is 0 or greater than 30_000.
    pub timeout_ms: u64,
    /// Default number of results to request (5 in `BrainConfig::default()`).
    pub default_top_k: usize,
}
204
/// Scheduling action settings (`actions.scheduling` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulingActionConfig {
    /// Whether the scheduling action is enabled.
    pub enabled: bool,
    /// Scheduling mode.
    pub mode: SchedulingMode,
}
210
/// Scheduling mode, serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulingMode {
    /// Serialized as `persist_only`; currently the only variant.
    PersistOnly,
}
216
/// One messaging channel (a value of `actions.messaging.channels`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelConfig {
    /// Target URL. [`BrainConfig::validate`] warns when it is blank.
    pub url: String,
    /// Request body; empty when omitted. The tests in this file use a
    /// `{{content}}` placeholder here — presumably a template, confirm
    /// against the dispatcher.
    #[serde(default)]
    pub body: String,
    /// Extra headers; empty when omitted.
    #[serde(default)]
    pub headers: HashMap<String, String>,
}
225
/// Messaging action settings (`actions.messaging` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagingActionConfig {
    /// Whether the messaging action is enabled.
    pub enabled: bool,
    /// Request timeout in milliseconds. [`BrainConfig::validate`] warns when
    /// it is 0 or greater than 30_000.
    pub timeout_ms: u64,
    /// Channel name → channel settings. Each value may be written either as
    /// a bare URL string or as a full [`ChannelConfig`] mapping (handled by
    /// `deserialize_channels`); empty when the key is omitted.
    #[serde(deserialize_with = "deserialize_channels", default)]
    pub channels: HashMap<String, ChannelConfig>,
}
233
234fn deserialize_channels<'de, D>(deserializer: D) -> Result<HashMap<String, ChannelConfig>, D::Error>
236where
237 D: serde::Deserializer<'de>,
238{
239 #[derive(Deserialize)]
240 #[serde(untagged)]
241 enum ChannelEntry {
242 Full(ChannelConfig),
243 UrlOnly(String),
244 }
245
246 let raw: HashMap<String, ChannelEntry> = HashMap::deserialize(deserializer)?;
247 Ok(raw
248 .into_iter()
249 .map(|(k, v)| {
250 let config = match v {
251 ChannelEntry::Full(c) => c,
252 ChannelEntry::UrlOnly(url) => ChannelConfig {
253 url,
254 body: String::new(),
255 headers: HashMap::new(),
256 },
257 };
258 (k, config)
259 })
260 .collect())
261}
262
/// Proactivity settings (`proactivity` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProactivityConfig {
    /// Whether proactivity is enabled.
    pub enabled: bool,
    /// Daily limit (5 in `BrainConfig::default()`).
    pub max_per_day: u32,
    /// Minimum interval in minutes (60 in `BrainConfig::default()`).
    pub min_interval_minutes: u32,
    /// Quiet-hours window.
    pub quiet_hours: QuietHoursConfig,
    /// Delivery settings; `DeliveryConfig::default()` when omitted.
    #[serde(default)]
    pub delivery: DeliveryConfig,
    /// Open-loop detection settings; `OpenLoopDetectionConfig::default()`
    /// when omitted.
    #[serde(default)]
    pub open_loop: OpenLoopDetectionConfig,
}
274
/// Open-loop detection settings (`proactivity.open_loop` section).
///
/// There are no per-field serde defaults: when the section is present, all
/// keys must be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenLoopDetectionConfig {
    /// Whether open-loop detection is enabled.
    pub enabled: bool,
    /// Scan window in hours.
    pub scan_window_hours: u32,
    /// Resolution window in hours.
    pub resolution_window_hours: u32,
    /// Check interval in minutes.
    pub check_interval_minutes: u32,
}
287
288impl Default for OpenLoopDetectionConfig {
289 fn default() -> Self {
290 Self {
291 enabled: true,
292 scan_window_hours: 72,
293 resolution_window_hours: 24,
294 check_interval_minutes: 120,
295 }
296 }
297}
298
/// Delivery settings (`proactivity.delivery` section).
///
/// There are no per-field serde defaults: when the section is present, all
/// keys must be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeliveryConfig {
    /// Outbox delivery toggle.
    pub outbox: bool,
    /// Broadcast delivery toggle.
    pub broadcast: bool,
    /// Webhook channel names — presumably keys of
    /// `actions.messaging.channels`; confirm against the dispatcher.
    pub webhook_channels: Vec<String>,
    /// Maximum outbox age in days.
    pub max_outbox_age_days: u32,
}
311
312impl Default for DeliveryConfig {
313 fn default() -> Self {
314 Self {
315 outbox: true,
316 broadcast: true,
317 webhook_channels: Vec::new(),
318 max_outbox_age_days: 7,
319 }
320 }
321}
322
/// Quiet-hours window (`proactivity.quiet_hours` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuietHoursConfig {
    /// Window start, stored as a string (`"22:00"` in `BrainConfig::default()`).
    pub start: String,
    /// Window end, stored as a string (`"08:00"` in `BrainConfig::default()`).
    pub end: String,
    /// Timezone name; `"UTC"` when omitted.
    #[serde(default = "default_timezone")]
    pub timezone: String,
}
330
/// Serde fallback for `QuietHoursConfig::timezone` when the key is absent.
fn default_timezone() -> String {
    String::from("UTC")
}
334
/// One API key entry (an element of `access.api_keys`).
///
/// NOTE(review): `Debug` is derived, so formatting this value with `{:?}`
/// prints the secret `key` — keep it out of logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyConfig {
    /// The secret key value that clients present.
    pub key: String,
    /// Human-readable label for the key.
    pub name: String,
    /// Permission strings granted to this key (e.g. `"read"`, `"write"` in
    /// `BrainConfig::default()`).
    pub permissions: Vec<String>,
}
345
346impl ApiKeyConfig {
347 pub fn has_permission(&self, perm: &str) -> bool {
349 self.permissions.iter().any(|p| p == perm)
350 }
351}
352
/// Access control settings (`access` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
    /// Configured API keys. [`BrainConfig::validate`] warns when this list is
    /// empty or still contains the demo key.
    pub api_keys: Vec<ApiKeyConfig>,
}
358
359impl AccessConfig {
360 pub fn validate(&self, key: &str, permission: &str) -> bool {
362 self.api_keys
363 .iter()
364 .any(|k| k.key == key && k.has_permission(permission))
365 }
366
367 pub fn find_key(&self, key: &str) -> Option<&ApiKeyConfig> {
369 self.api_keys.iter().find(|k| k.key == key)
370 }
371}
372
/// Adapter settings (`adapters` section).
///
/// [`BrainConfig::validate`] rejects configurations in which two of the four
/// adapters share a port number.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptersConfig {
    /// HTTP adapter.
    pub http: HttpAdapterConfig,
    /// WebSocket adapter.
    pub ws: WebSocketAdapterConfig,
    /// MCP adapter.
    pub mcp: McpAdapterConfig,
    /// gRPC adapter.
    pub grpc: GrpcAdapterConfig,
}
380
/// HTTP adapter settings (`adapters.http` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpAdapterConfig {
    /// Whether the adapter is enabled.
    pub enabled: bool,
    /// Host (`"127.0.0.1"` in `BrainConfig::default()`).
    pub host: String,
    /// Port (19789 in `BrainConfig::default()`).
    pub port: u16,
    /// CORS toggle.
    pub cors: bool,
}
388
/// WebSocket adapter settings (`adapters.ws` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketAdapterConfig {
    /// Whether the adapter is enabled.
    pub enabled: bool,
    /// Port (19790 in `BrainConfig::default()`).
    pub port: u16,
}
394
/// MCP adapter settings (`adapters.mcp` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpAdapterConfig {
    /// Whether the adapter is enabled.
    pub enabled: bool,
    /// stdio transport toggle.
    pub stdio: bool,
    /// HTTP transport toggle.
    pub http: bool,
    /// Port (19791 in `BrainConfig::default()`).
    pub port: u16,
}
402
/// gRPC adapter settings (`adapters.grpc` section).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcAdapterConfig {
    /// Whether the adapter is enabled.
    pub enabled: bool,
    /// Port (19792 in `BrainConfig::default()`).
    pub port: u16,
}
408
409impl BrainConfig {
410 #[allow(clippy::result_large_err)]
417 pub fn load() -> Result<Self, figment::Error> {
418 Self::load_from(None)
419 }
420
421 #[allow(clippy::result_large_err)]
423 pub fn load_from(config_path: Option<&Path>) -> Result<Self, figment::Error> {
424 let mut figment = Figment::new().merge(Yaml::string(DEFAULT_CONFIG));
426
427 let user_config = Self::user_config_path();
429 if user_config.exists() {
430 figment = figment.merge(Yaml::file(&user_config));
431 }
432
433 if let Some(path) = config_path {
435 figment = figment.merge(Yaml::file(path));
436 }
437
438 figment = figment.merge(Env::prefixed("BRAIN_").split("__"));
440
441 figment.extract()
442 }
443
444 pub fn data_dir(&self) -> PathBuf {
446 expand_tilde(&self.brain.data_dir)
447 }
448
449 pub fn ensure_data_dirs(&self) -> std::io::Result<()> {
451 let data_dir = self.data_dir();
452 let dirs = [
453 data_dir.clone(),
454 data_dir.join("db"), data_dir.join("ruvector"), data_dir.join("models"), data_dir.join("logs"), data_dir.join("exports"), ];
460
461 for dir in &dirs {
462 std::fs::create_dir_all(dir)?;
463 }
464
465 Ok(())
466 }
467
468 pub fn sqlite_path(&self) -> PathBuf {
470 self.data_dir().join("db").join("brain.db")
471 }
472
473 pub fn ruvector_path(&self) -> PathBuf {
475 self.data_dir().join("ruvector")
476 }
477
478 pub fn models_path(&self) -> PathBuf {
480 self.data_dir().join("models")
481 }
482
483 pub fn is_initialized() -> bool {
485 expand_tilde("~/.brain").exists()
486 }
487
488 pub fn write_default_config(force: bool) -> std::io::Result<Option<(PathBuf, String)>> {
497 let config_path = Self::user_config_path();
498
499 if config_path.exists() && !force {
500 return Ok(None);
501 }
502
503 if let Some(parent) = config_path.parent() {
505 std::fs::create_dir_all(parent)?;
506 }
507
508 let api_key = Self::generate_api_key();
510 let config = DEFAULT_CONFIG.replace("demokey123", &api_key);
511
512 std::fs::write(&config_path, config)?;
513 Ok(Some((config_path, api_key)))
514 }
515
516 fn generate_api_key() -> String {
518 let mut buf = [0u8; 16];
519 getrandom::getrandom(&mut buf).expect("failed to obtain random bytes from OS");
520 let hex: String = buf.iter().map(|b| format!("{:02x}", b)).collect();
521 format!("brk_{}", hex)
522 }
523
524 pub fn user_config_path() -> PathBuf {
526 expand_tilde("~/.brain/config.yaml")
527 }
528
529 pub fn default_config_content() -> &'static str {
531 DEFAULT_CONFIG
532 }
533
534 pub fn validate(&self) -> Result<Vec<String>, String> {
540 let mut warnings: Vec<String> = Vec::new();
541
542 let mut ports: std::collections::HashMap<u16, &str> = std::collections::HashMap::new();
544 let adapter_ports = [
545 (self.adapters.http.port, "http"),
546 (self.adapters.ws.port, "ws"),
547 (self.adapters.mcp.port, "mcp"),
548 (self.adapters.grpc.port, "grpc"),
549 ];
550 for (port, name) in &adapter_ports {
551 if let Some(existing) = ports.insert(*port, name) {
552 return Err(format!(
553 "Port conflict: adapters '{}' and '{}' both use port {}",
554 existing, name, port
555 ));
556 }
557 }
558
559 let url = &self.llm.base_url;
561 if !url.starts_with("http://") && !url.starts_with("https://") {
562 return Err(format!(
563 "Invalid LLM base_url '{}': must start with http:// or https://",
564 url
565 ));
566 }
567
568 let data_dir = self.data_dir();
570 if data_dir.exists() {
571 let probe = data_dir.join(".brain_write_probe");
573 if std::fs::write(&probe, b"").is_err() {
574 return Err(format!(
575 "Data directory '{}' is not writable",
576 data_dir.display()
577 ));
578 }
579 let _ = std::fs::remove_file(&probe);
580 }
581
582 if self.access.api_keys.is_empty() {
584 warnings.push("No API keys configured — all adapters will reject authenticated requests. Add at least one key under 'access.api_keys'.".to_string());
585 } else if self.access.api_keys.iter().any(|k| k.key == "demokey123") {
586 warnings.push("Demo API key 'demokey123' is still active. Replace it with a strong key in production.".to_string());
587 }
588
589 if self.llm.temperature > 1.5 {
590 warnings.push(format!(
591 "LLM temperature {:.1} is very high — responses may be unpredictable.",
592 self.llm.temperature
593 ));
594 }
595
596 if self.memory.consolidation.enabled && self.memory.consolidation.interval_hours == 0 {
597 warnings.push("Consolidation interval_hours is 0 — consolidation will run immediately on every daemon wake-up, which may impact performance.".to_string());
598 }
599
600 if self.actions.web_search.enabled {
601 match self.actions.web_search.provider {
602 WebSearchProvider::Custom if self.actions.web_search.endpoint.trim().is_empty() => {
603 warnings.push("Actions web_search provider is 'custom' but endpoint is empty; dispatches will fail with backend-not-configured.".to_string());
604 }
605 WebSearchProvider::Tavily if self.actions.web_search.api_key.trim().is_empty() => {
606 warnings.push("Actions web_search provider is 'tavily' but api_key is empty; dispatches will fail.".to_string());
607 }
608 _ => {}
609 }
610 }
611
612 if self.actions.messaging.enabled {
613 if self.actions.messaging.channels.is_empty() {
614 warnings.push("Actions messaging is enabled but actions.messaging.channels has no mappings; dispatches will fail for all channels.".to_string());
615 } else {
616 for (name, channel_cfg) in &self.actions.messaging.channels {
617 if channel_cfg.url.trim().is_empty() {
618 warnings.push(format!(
619 "actions.messaging.channels.{name}: url is empty; dispatches to this channel will fail."
620 ));
621 }
622 }
623 }
624 }
625
626 for (name, ms) in [
628 ("web_search.timeout_ms", self.actions.web_search.timeout_ms),
629 ("messaging.timeout_ms", self.actions.messaging.timeout_ms),
630 ] {
631 if ms == 0 {
632 warnings.push(format!(
633 "actions.{name} is 0; will be clamped to 1ms at runtime."
634 ));
635 } else if ms > 30_000 {
636 warnings.push(format!(
637 "actions.{name} is {}ms (>30s) — requests may block for a long time.",
638 ms
639 ));
640 }
641 }
642
643 let res = &self.actions.resilience;
645 if res.max_retries > 10 {
646 warnings.push(format!("actions.resilience.max_retries is {} (>10) — excessive retries may amplify failures.", res.max_retries));
647 }
648 if res.circuit_breaker_threshold == 0 {
649 warnings.push("actions.resilience.circuit_breaker_threshold is 0; circuit breaker will never trip.".to_string());
650 }
651
652 Ok(warnings)
653 }
654}
655
656impl Default for BrainConfig {
657 fn default() -> Self {
658 Self {
659 brain: GeneralConfig {
660 version: env!("CARGO_PKG_VERSION").to_string(),
661 data_dir: "~/.brain".to_string(),
662 },
663 storage: StorageConfig {
664 ruvector_path: "~/.brain/ruvector/".to_string(),
665 sqlite_path: "~/.brain/db/brain.db".to_string(),
666 hnsw: HnswConfig {
667 ef_construction: 200,
668 m: 16,
669 ef_search: 50,
670 },
671 },
672 llm: LlmConfig {
673 provider: "ollama".to_string(),
674 model: "qwen2.5-coder:7b".to_string(),
675 base_url: "http://localhost:11434".to_string(),
676 temperature: 0.7,
677 max_tokens: 4096,
678 api_key: String::new(),
679 intent_llm_fallback: false,
680 },
681 embedding: EmbeddingConfig {
682 model: "nomic-embed-text".to_string(),
683 dimensions: 768,
684 },
685 memory: MemoryConfig {
686 episodic: EpisodicConfig {
687 max_entries: 100_000,
688 retention_days: 365,
689 },
690 semantic: SemanticConfig {
691 similarity_threshold: 0.65,
692 max_results: 20,
693 },
694 search: SearchConfig {
695 hybrid_weight: 0.7,
696 rrf_k: 60,
697 pre_fusion_limit: 50,
698 importance_weight: 0.3,
699 recency_weight: 0.2,
700 decay_rate: 0.01,
701 },
702 consolidation: ConsolidationConfig {
703 enabled: true,
704 interval_hours: 24,
705 forgetting_threshold: 0.05,
706 },
707 },
708 encryption: EncryptionConfig { enabled: false }, security: SecurityConfig {
710 exec_allowlist: vec![
711 "ls".into(),
712 "cat".into(),
713 "grep".into(),
714 "find".into(),
715 "git".into(),
716 "cargo".into(),
717 "rustc".into(),
718 ],
719 exec_timeout_seconds: 30,
720 },
721 actions: ActionsConfig {
722 web_search: WebSearchActionConfig {
723 enabled: true,
724 provider: WebSearchProvider::Searxng,
725 endpoint: "http://localhost:8888".to_string(),
726 api_key: String::new(),
727 timeout_ms: 3_000,
728 default_top_k: 5,
729 },
730 scheduling: SchedulingActionConfig {
731 enabled: false,
732 mode: SchedulingMode::PersistOnly,
733 },
734 messaging: MessagingActionConfig {
735 enabled: false,
736 timeout_ms: 3_000,
737 channels: HashMap::new(),
738 },
739 resilience: ResilienceConfig::default(),
740 },
741 proactivity: ProactivityConfig {
742 enabled: false,
743 max_per_day: 5,
744 min_interval_minutes: 60,
745 quiet_hours: QuietHoursConfig {
746 start: "22:00".to_string(),
747 end: "08:00".to_string(),
748 timezone: "UTC".to_string(),
749 },
750 delivery: DeliveryConfig::default(),
751 open_loop: OpenLoopDetectionConfig::default(),
752 },
753 adapters: AdaptersConfig {
754 http: HttpAdapterConfig {
755 enabled: true,
756 host: "127.0.0.1".to_string(),
757 port: 19789,
758 cors: true,
759 },
760 ws: WebSocketAdapterConfig {
761 enabled: true,
762 port: 19790,
763 },
764 mcp: McpAdapterConfig {
765 enabled: true,
766 stdio: true,
767 http: true,
768 port: 19791,
769 },
770 grpc: GrpcAdapterConfig {
771 enabled: true,
772 port: 19792,
773 },
774 },
775 access: AccessConfig {
776 api_keys: vec![ApiKeyConfig {
777 key: "demokey123".to_string(),
778 name: "Demo Key".to_string(),
779 permissions: vec!["read".to_string(), "write".to_string()],
780 }],
781 },
782 }
783 }
784}
785
786fn expand_tilde(path: &str) -> PathBuf {
788 if let Some(rest) = path.strip_prefix("~/") {
789 if let Some(home) = dirs_home() {
790 return home.join(rest);
791 }
792 }
793 PathBuf::from(path)
794}
795
/// Returns the current user's home directory from the environment.
///
/// `HOME` is consulted first and `USERPROFILE` (the Windows convention)
/// second. A variable that is set but empty is treated as unset, so it can
/// never yield an empty path that would turn `~/x` into the relative path
/// `x`. Returns `None` when neither variable holds a value.
fn dirs_home() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|var| std::env::var_os(var))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
800
/// Unit tests for defaults, path helpers, validation rules and the two
/// accepted spellings of messaging channels.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = BrainConfig::default();
        assert_eq!(config.brain.data_dir, "~/.brain");
        assert_eq!(config.llm.provider, "ollama");
        assert_eq!(config.embedding.dimensions, 768);
        assert!(!config.encryption.enabled);
        assert_eq!(
            config.actions.web_search.provider,
            WebSearchProvider::Searxng
        );
        assert_eq!(config.actions.scheduling.mode, SchedulingMode::PersistOnly);
        assert!(!config.proactivity.enabled);
        assert!(config.adapters.http.enabled);
    }

    #[test]
    fn test_expand_tilde() {
        let expanded = expand_tilde("~/.brain");
        assert!(!expanded.to_str().unwrap().starts_with('~'));
        assert!(expanded.to_str().unwrap().ends_with(".brain"));
    }

    #[test]
    fn test_data_dir_paths() {
        let config = BrainConfig::default();
        let data = config.data_dir();
        assert!(data.to_str().unwrap().ends_with(".brain"));
        assert!(config.sqlite_path().to_str().unwrap().ends_with("brain.db"));
        assert!(config
            .ruvector_path()
            .to_str()
            .unwrap()
            .ends_with("ruvector"));
    }

    /// Round-trips `BrainConfig::default()` through figment.
    #[test]
    fn test_load_from_defaults() {
        use figment::providers::Serialized;
        let figment = Figment::new().merge(Serialized::defaults(BrainConfig::default()));
        let config: BrainConfig = figment.extract().unwrap();
        assert_eq!(config.llm.model, "qwen2.5-coder:7b");
        assert_eq!(config.memory.search.rrf_k, 60);
        assert_eq!(config.memory.search.pre_fusion_limit, 50);
        assert!((config.memory.search.importance_weight - 0.3).abs() < f64::EPSILON);
        assert!((config.memory.search.recency_weight - 0.2).abs() < f64::EPSILON);
        assert!((config.memory.search.decay_rate - 0.01).abs() < f64::EPSILON);
    }

    /// Data directory under the system temp dir, used so `validate()` never
    /// probes the real `~/.brain`. The directory is not created here; when it
    /// does not exist, `validate()` skips the write probe.
    fn writable_test_data_dir() -> String {
        std::env::temp_dir()
            .join("brain-core-tests")
            .to_string_lossy()
            .to_string()
    }

    /// Default config pointed at the test data dir and with the demo API key
    /// removed (so the only access warning is "No API keys").
    fn validated_config() -> BrainConfig {
        let mut c = BrainConfig::default();
        c.brain.data_dir = writable_test_data_dir();
        c.access.api_keys.clear();
        c
    }

    #[test]
    fn test_validate_default_has_demo_key_warning() {
        let mut config = BrainConfig::default();
        config.brain.data_dir = writable_test_data_dir();
        let warnings = config.validate().expect("default config should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("demokey123")),
            "expected demo-key warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_no_api_keys_warning() {
        let config = validated_config();
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("No API keys")),
            "expected no-api-keys warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_port_conflict_is_hard_error() {
        let mut config = validated_config();
        config.adapters.ws.port = config.adapters.http.port;
        let err = config
            .validate()
            .expect_err("should fail with port conflict");
        assert!(
            err.contains("Port conflict"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn test_validate_bad_llm_url_is_hard_error() {
        let mut config = validated_config();
        config.llm.base_url = "ftp://invalid.example.com".to_string();
        let err = config.validate().expect_err("should fail with bad URL");
        assert!(
            err.contains("Invalid LLM base_url"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_validate_high_temperature_warning() {
        let mut config = validated_config();
        config.llm.temperature = 2.0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("temperature")),
            "expected temperature warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_consolidation_interval_zero_warning() {
        let mut config = validated_config();
        config.memory.consolidation.enabled = true;
        config.memory.consolidation.interval_hours = 0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("interval_hours")),
            "expected interval warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_actions_defaults_deserialize() {
        // Extract from the embedded defaults only. `BrainConfig::load()` would
        // also merge the developer's `~/.brain/config.yaml` and any `BRAIN_*`
        // environment variables, making the assertions machine-dependent.
        let config: BrainConfig = Figment::new()
            .merge(Yaml::string(DEFAULT_CONFIG))
            .extract()
            .expect("embedded defaults should load");
        assert!(config.actions.web_search.enabled);
        assert_eq!(
            config.actions.web_search.provider,
            WebSearchProvider::Searxng
        );
        assert_eq!(config.actions.web_search.default_top_k, 5);
        assert_eq!(config.actions.scheduling.mode, SchedulingMode::PersistOnly);
        assert!(!config.actions.messaging.enabled);
    }

    #[test]
    fn test_validate_actions_warning_custom_without_endpoint() {
        let mut config = validated_config();
        config.actions.web_search.enabled = true;
        config.actions.web_search.provider = WebSearchProvider::Custom;
        config.actions.web_search.endpoint.clear();
        config.actions.messaging.enabled = true;
        config.actions.messaging.channels.clear();
        let warnings = config.validate().expect("config should still be valid");
        assert!(warnings.iter().any(|w| w.contains("'custom'")));
        assert!(warnings.iter().any(|w| w.contains("messaging")));
    }

    #[test]
    fn test_validate_tavily_without_api_key_warning() {
        let mut config = validated_config();
        config.actions.web_search.enabled = true;
        config.actions.web_search.provider = WebSearchProvider::Tavily;
        config.actions.web_search.api_key.clear();
        let warnings = config.validate().expect("config should still be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("'tavily'") && w.contains("api_key")),
            "expected tavily api_key warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_searxng_no_web_search_warning() {
        let mut config = validated_config();
        config.actions.web_search.enabled = true;
        config.actions.web_search.provider = WebSearchProvider::Searxng;
        let warnings = config.validate().expect("config should still be valid");
        assert!(
            !warnings.iter().any(|w| w.contains("web_search")),
            "SearXNG with default endpoint should not trigger web_search warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_http_and_https_urls_accepted() {
        let mut config = validated_config();
        config.llm.base_url = "https://api.example.com/v1".to_string();
        assert!(config.validate().is_ok());

        config.llm.base_url = "http://localhost:11434".to_string();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_all_unique_ports_ok() {
        let config = validated_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_timeout_zero_warning() {
        let mut config = validated_config();
        config.actions.web_search.timeout_ms = 0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("timeout_ms") && w.contains("0")),
            "expected timeout_ms=0 warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_timeout_too_high_warning() {
        let mut config = validated_config();
        config.actions.messaging.timeout_ms = 60_000;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("timeout_ms") && w.contains("60000")),
            "expected high timeout warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_resilience_max_retries_warning() {
        let mut config = validated_config();
        config.actions.resilience.max_retries = 15;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("max_retries") && w.contains("15")),
            "expected max_retries warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_resilience_threshold_zero_warning() {
        let mut config = validated_config();
        config.actions.resilience.circuit_breaker_threshold = 0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("circuit_breaker_threshold")),
            "expected circuit_breaker_threshold=0 warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_resilience_defaults() {
        let res = ResilienceConfig::default();
        assert_eq!(res.max_retries, 2);
        assert_eq!(res.retry_base_ms, 500);
        assert_eq!(res.circuit_breaker_threshold, 5);
        assert_eq!(res.circuit_breaker_cooldown_secs, 60);
    }

    /// Channels written as bare URL strings must still deserialize.
    #[test]
    fn test_channel_config_old_format_compat() {
        let yaml = r#"
            enabled: false
            timeout_ms: 3000
            channels:
              alerts: "https://example.com/hook"
              ops: "https://slack.example.com/webhook"
        "#;
        let cfg: MessagingActionConfig =
            serde_yaml::from_str(yaml).expect("old format should deserialize");
        assert_eq!(cfg.channels.len(), 2);
        assert_eq!(cfg.channels["alerts"].url, "https://example.com/hook");
        assert!(cfg.channels["alerts"].body.is_empty());
        assert!(cfg.channels["alerts"].headers.is_empty());
    }

    /// Channels written as full mappings (url / body / headers).
    #[test]
    fn test_channel_config_new_format() {
        let yaml = r#"
            enabled: true
            timeout_ms: 3000
            channels:
              alerts:
                url: "https://hooks.slack.com/services/T/B/x"
                body: '{"text": "{{content}}"}'
                headers:
                  Authorization: "Bearer tok123"
        "#;
        let cfg: MessagingActionConfig =
            serde_yaml::from_str(yaml).expect("new format should deserialize");
        assert_eq!(cfg.channels.len(), 1);
        let ch = &cfg.channels["alerts"];
        assert_eq!(ch.url, "https://hooks.slack.com/services/T/B/x");
        assert_eq!(ch.body, r#"{"text": "{{content}}"}"#);
        assert_eq!(ch.headers["Authorization"], "Bearer tok123");
    }

    /// Both spellings may be mixed within one `channels` map.
    #[test]
    fn test_channel_config_mixed_format() {
        let yaml = r#"
            enabled: true
            timeout_ms: 3000
            channels:
              simple: "https://example.com/hook"
              custom:
                url: "https://discord.com/api/webhooks/123/abc"
                body: '{"content": "{{content}}"}'
        "#;
        let cfg: MessagingActionConfig =
            serde_yaml::from_str(yaml).expect("mixed format should deserialize");
        assert_eq!(cfg.channels.len(), 2);
        assert_eq!(cfg.channels["simple"].url, "https://example.com/hook");
        assert!(cfg.channels["simple"].body.is_empty());
        let custom = &cfg.channels["custom"];
        assert_eq!(custom.url, "https://discord.com/api/webhooks/123/abc");
        assert!(!custom.body.is_empty());
        assert!(custom.headers.is_empty());
    }

    #[test]
    fn test_validate_channel_empty_url_warning() {
        let mut config = validated_config();
        config.actions.messaging.enabled = true;
        config.actions.messaging.channels.insert(
            "bad".into(),
            ChannelConfig {
                url: "".into(),
                body: String::new(),
                headers: HashMap::new(),
            },
        );
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("channels.bad") && w.contains("url is empty")),
            "expected empty-url warning, got: {:?}",
            warnings
        );
    }
}