/// Built-in default configuration (YAML), embedded at compile time from
/// `../default.yaml` (relative to this source file).
///
/// It is the first (lowest-priority) source merged by `BrainConfig::load_from`
/// and the template that `BrainConfig::write_default_config` writes out.
const DEFAULT_CONFIG: &str = include_str!("../default.yaml");
11
12use std::{
13 collections::HashMap,
14 path::{Path, PathBuf},
15};
16
17use figment::{
18 providers::{Env, Format, Yaml},
19 Figment,
20};
21use serde::{Deserialize, Serialize};
22
/// Complete application configuration.
///
/// Each field maps to a top-level section of the YAML config with the same
/// name (`brain`, `storage`, `llm`, ...). Instances are normally produced by
/// [`BrainConfig::load`] / [`BrainConfig::load_from`] and checked with
/// [`BrainConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrainConfig {
    /// `brain`: version and data directory.
    pub brain: GeneralConfig,
    /// `storage`: storage paths and vector-index parameters.
    pub storage: StorageConfig,
    /// `llm`: language-model backend settings.
    pub llm: LlmConfig,
    /// `embedding`: embedding model settings.
    pub embedding: EmbeddingConfig,
    /// `memory`: episodic/semantic memory, search and consolidation.
    pub memory: MemoryConfig,
    /// `encryption`: encryption toggle.
    pub encryption: EncryptionConfig,
    /// `security`: command-execution restrictions.
    pub security: SecurityConfig,
    /// `actions`: web search, scheduling, messaging and retry policy.
    pub actions: ActionsConfig,
    /// `proactivity`: limits and delivery for proactive behaviour.
    pub proactivity: ProactivityConfig,
    /// `adapters`: HTTP / WebSocket / MCP / gRPC front-ends.
    pub adapters: AdaptersConfig,
    /// `access`: API keys.
    pub access: AccessConfig,
}
38
/// The `brain` section: general settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
    /// Config version string; the built-in `Default` uses the crate version.
    pub version: String,
    /// Root data directory. A leading `~/` is expanded by [`BrainConfig::data_dir`].
    pub data_dir: String,
}
44
/// The `storage` section.
///
/// NOTE(review): `BrainConfig::sqlite_path` and `BrainConfig::ruvector_path`
/// derive their paths from `brain.data_dir` and do not read `sqlite_path` /
/// `ruvector_path` here. Confirm whether these two fields are still consumed elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Vector store location (built-in default `~/.brain/ruvector/`).
    pub ruvector_path: String,
    /// SQLite database location (built-in default `~/.brain/db/brain.db`).
    pub sqlite_path: String,
    /// Vector-index tuning parameters.
    pub hnsw: HnswConfig,
}
51
/// HNSW vector-index tuning parameters (`storage.hnsw`).
///
/// Field names follow common HNSW terminology; how they are applied to the
/// index is outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// Build-time candidate list size (built-in default 200).
    pub ef_construction: u32,
    /// Graph connectivity parameter `M` (built-in default 16).
    pub m: u32,
    /// Query-time candidate list size (built-in default 50).
    pub ef_search: u32,
}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct LlmConfig {
61 pub provider: String,
62 pub model: String,
63 pub base_url: String,
64 pub temperature: f64,
65 pub max_tokens: u32,
66 #[serde(default)]
69 pub api_key: String,
70}
71
/// The `embedding` section: embedding model name and vector size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Embedding model identifier (built-in default `nomic-embed-text`).
    pub model: String,
    /// Length of the embedding vectors (built-in default 768).
    pub dimensions: u32,
}
82
/// The `memory` section, grouping episodic, semantic, search and
/// consolidation settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    pub episodic: EpisodicConfig,
    pub semantic: SemanticConfig,
    pub search: SearchConfig,
    pub consolidation: ConsolidationConfig,
}
90
/// `memory.episodic` settings.
///
/// Per the warnings in `BrainConfig::validate`, neither field is currently
/// enforced. Both are reserved for future pruning/retention logic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodicConfig {
    /// Reserved: maximum number of episodic entries (built-in default 100 000).
    pub max_entries: u64,
    /// Reserved: retention period in days (built-in default 365).
    pub retention_days: u32,
}
96
/// `memory.semantic` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
    /// Similarity cut-off (built-in default 0.65); presumably results below it are dropped. Confirm with the caller.
    pub similarity_threshold: f64,
    /// Result cap (built-in default 20).
    pub max_results: u32,
}
102
/// `memory.search` settings for recall ranking.
///
/// The last four fields fall back to the `default_*` helpers below when they
/// are absent from the config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
    /// Deprecated and unused: recall uses Reciprocal Rank Fusion (`rrf_k`)
    /// instead; `BrainConfig::validate` warns when it differs from 0.7.
    pub hybrid_weight: f64,
    /// The `k` constant of Reciprocal Rank Fusion.
    pub rrf_k: u32,
    /// Candidate limit before fusion (default 50).
    #[serde(default = "default_pre_fusion_limit")]
    pub pre_fusion_limit: u32,
    /// Weight of the importance signal (default 0.3).
    #[serde(default = "default_importance_weight")]
    pub importance_weight: f64,
    /// Weight of the recency signal (default 0.2).
    #[serde(default = "default_recency_weight")]
    pub recency_weight: f64,
    /// Decay rate of the forgetting curve used by recall (default 0.01).
    #[serde(default = "default_decay_rate")]
    pub decay_rate: f64,
}
120
/// Serde fallback for `memory.search.pre_fusion_limit`.
fn default_pre_fusion_limit() -> u32 {
    50_u32
}
/// Serde fallback for `memory.search.importance_weight`.
fn default_importance_weight() -> f64 {
    0.3_f64
}
/// Serde fallback for `memory.search.recency_weight`.
fn default_recency_weight() -> f64 {
    0.2_f64
}
/// Serde fallback for `memory.search.decay_rate`.
fn default_decay_rate() -> f64 {
    0.01_f64
}
133
/// `memory.consolidation` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsolidationConfig {
    /// Whether consolidation runs at all.
    pub enabled: bool,
    /// Hours between runs; 0 triggers a warning in `BrainConfig::validate`.
    pub interval_hours: u32,
    /// Forgetting threshold (built-in default 0.05).
    pub forgetting_threshold: f64,
}
140
/// The `encryption` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionConfig {
    /// Whether encryption is enabled (off in the built-in defaults).
    pub enabled: bool,
}
145
/// The `security` section: restrictions on command execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Commands allowed to run (built-in defaults: ls, grep, find, git, cargo,
    /// rustc). NOTE(review): presumably matched by program name. Confirm in the executor.
    pub exec_allowlist: Vec<String>,
    /// Execution timeout in seconds (built-in default 30).
    pub exec_timeout_seconds: u32,
}
151
/// The `actions` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionsConfig {
    pub web_search: WebSearchActionConfig,
    pub scheduling: SchedulingActionConfig,
    pub messaging: MessagingActionConfig,
    /// Retry / circuit-breaker policy; `ResilienceConfig::default()` when omitted.
    #[serde(default)]
    pub resilience: ResilienceConfig,
}
160
/// `actions.resilience`: retry and circuit-breaker settings for actions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResilienceConfig {
    /// Maximum retries; `BrainConfig::validate` warns above 10.
    pub max_retries: u32,
    /// Base retry delay in milliseconds.
    pub retry_base_ms: u64,
    /// Breaker threshold; 0 means the breaker never trips (validation warning).
    pub circuit_breaker_threshold: u32,
    /// Breaker cooldown in seconds.
    pub circuit_breaker_cooldown_secs: u64,
}
168
169impl Default for ResilienceConfig {
170 fn default() -> Self {
171 Self {
172 max_retries: 2,
173 retry_base_ms: 500,
174 circuit_breaker_threshold: 5,
175 circuit_breaker_cooldown_secs: 60,
176 }
177 }
178}
179
/// Web-search backend (`actions.web_search.provider`).
///
/// Serialized in snake_case: `searxng`, `tavily`, `custom`. `Custom` is the
/// `Default`, which is what a config gets when `provider` is omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WebSearchProvider {
    /// SearXNG; the built-in defaults pair it with `http://localhost:8888`.
    Searxng,
    /// Tavily; `BrainConfig::validate` warns if `api_key` is empty.
    Tavily,
    /// Custom backend; `BrainConfig::validate` warns if `endpoint` is empty.
    #[default]
    Custom,
}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct WebSearchActionConfig {
191 pub enabled: bool,
192 #[serde(default)]
193 pub provider: WebSearchProvider,
194 pub endpoint: String,
195 #[serde(default)]
196 pub api_key: String,
197 pub timeout_ms: u64,
198 pub default_top_k: usize,
199}
200
/// `actions.scheduling` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulingActionConfig {
    /// Whether the scheduling action is available (off by default).
    pub enabled: bool,
    /// Scheduling mode; currently only `persist_only`.
    pub mode: SchedulingMode,
}
206
/// Scheduling mode, serialized in snake_case (`persist_only`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SchedulingMode {
    /// The only supported mode. NOTE(review): presumably schedules are stored but not executed. Confirm.
    PersistOnly,
}
212
/// Settings for one messaging channel (`actions.messaging.channels.<name>`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelConfig {
    /// Target URL; `BrainConfig::validate` warns when it is blank.
    pub url: String,
    /// Request body; empty when omitted. Looks like a template (e.g.
    /// `{"text": "{{content}}"}`). Substitution happens elsewhere. Confirm.
    #[serde(default)]
    pub body: String,
    /// Extra request headers; empty when omitted.
    /// NOTE(review): these may carry secrets (e.g. `Authorization`) and are
    /// included in the derived `Debug` output.
    #[serde(default)]
    pub headers: HashMap<String, String>,
}
221
/// `actions.messaging` settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagingActionConfig {
    /// Whether the messaging action is available.
    pub enabled: bool,
    /// Request timeout in milliseconds.
    pub timeout_ms: u64,
    /// Channels by name. Each entry may be a bare URL string or a full
    /// [`ChannelConfig`] mapping (see `deserialize_channels`); empty when omitted.
    #[serde(deserialize_with = "deserialize_channels", default)]
    pub channels: HashMap<String, ChannelConfig>,
}
229
230fn deserialize_channels<'de, D>(deserializer: D) -> Result<HashMap<String, ChannelConfig>, D::Error>
232where
233 D: serde::Deserializer<'de>,
234{
235 #[derive(Deserialize)]
236 #[serde(untagged)]
237 enum ChannelEntry {
238 Full(ChannelConfig),
239 UrlOnly(String),
240 }
241
242 let raw: HashMap<String, ChannelEntry> = HashMap::deserialize(deserializer)?;
243 Ok(raw
244 .into_iter()
245 .map(|(k, v)| {
246 let config = match v {
247 ChannelEntry::Full(c) => c,
248 ChannelEntry::UrlOnly(url) => ChannelConfig {
249 url,
250 body: String::new(),
251 headers: HashMap::new(),
252 },
253 };
254 (k, config)
255 })
256 .collect())
257}
258
/// The `proactivity` section: limits and delivery for proactive behaviour
/// (disabled in the built-in defaults).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProactivityConfig {
    pub enabled: bool,
    /// Daily cap (built-in default 5).
    pub max_per_day: u32,
    /// Minimum gap in minutes (built-in default 60).
    pub min_interval_minutes: u32,
    /// Quiet-hours window.
    pub quiet_hours: QuietHoursConfig,
    /// Delivery targets; `DeliveryConfig::default()` when omitted.
    #[serde(default)]
    pub delivery: DeliveryConfig,
    /// Open-loop detection; `OpenLoopDetectionConfig::default()` when omitted.
    #[serde(default)]
    pub open_loop: OpenLoopDetectionConfig,
}
270
/// `proactivity.open_loop` settings. Field meanings below are inferred from
/// names; the consumer lives outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenLoopDetectionConfig {
    pub enabled: bool,
    /// Look-back window in hours (default 72).
    pub scan_window_hours: u32,
    /// Resolution window in hours (default 24).
    pub resolution_window_hours: u32,
    /// Check period in minutes (default 120).
    pub check_interval_minutes: u32,
}
283
284impl Default for OpenLoopDetectionConfig {
285 fn default() -> Self {
286 Self {
287 enabled: true,
288 scan_window_hours: 72,
289 resolution_window_hours: 24,
290 check_interval_minutes: 120,
291 }
292 }
293}
294
/// `proactivity.delivery`: where proactive output is delivered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeliveryConfig {
    /// Deliver to the outbox (default true).
    pub outbox: bool,
    /// Broadcast delivery (default true).
    pub broadcast: bool,
    /// Webhook channel names (default empty). NOTE(review): presumably keys of
    /// `actions.messaging.channels`. Confirm.
    pub webhook_channels: Vec<String>,
    /// Outbox age limit in days (default 7).
    pub max_outbox_age_days: u32,
}
307
308impl Default for DeliveryConfig {
309 fn default() -> Self {
310 Self {
311 outbox: true,
312 broadcast: true,
313 webhook_channels: Vec::new(),
314 max_outbox_age_days: 7,
315 }
316 }
317}
318
/// `proactivity.quiet_hours`: a daily window (built-in default 22:00–08:00).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuietHoursConfig {
    /// Window start, e.g. `"22:00"`; the string is parsed elsewhere.
    pub start: String,
    /// Window end, e.g. `"08:00"`.
    pub end: String,
    /// Timezone name; `"UTC"` when omitted.
    #[serde(default = "default_timezone")]
    pub timezone: String,
}
326
/// Serde fallback for `proactivity.quiet_hours.timezone`.
fn default_timezone() -> String {
    String::from("UTC")
}
330
331#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct ApiKeyConfig {
334 pub key: String,
336 pub name: String,
338 pub permissions: Vec<String>,
340}
341
342impl ApiKeyConfig {
343 pub fn has_permission(&self, perm: &str) -> bool {
345 self.permissions.iter().any(|p| p == perm)
346 }
347}
348
/// The `access` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
    /// Accepted API keys; an empty list triggers a warning in `BrainConfig::validate`.
    pub api_keys: Vec<ApiKeyConfig>,
}
354
355impl AccessConfig {
356 pub fn find_key(&self, key: &str) -> Option<&ApiKeyConfig> {
358 self.api_keys.iter().find(|k| k.key == key)
359 }
360}
361
/// The `adapters` section. `BrainConfig::validate` checks the adapters' ports
/// for conflicts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptersConfig {
    pub http: HttpAdapterConfig,
    pub ws: WebSocketAdapterConfig,
    pub mcp: McpAdapterConfig,
    pub grpc: GrpcAdapterConfig,
}
369
/// `adapters.http` settings (built-in default `127.0.0.1:19789`, CORS on).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpAdapterConfig {
    pub enabled: bool,
    /// Bind host.
    pub host: String,
    pub port: u16,
    /// Whether CORS is enabled.
    pub cors: bool,
}
377
/// `adapters.ws` settings (built-in default port 19790).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSocketAdapterConfig {
    pub enabled: bool,
    pub port: u16,
}
383
/// `adapters.mcp` settings (built-in default port 19791).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpAdapterConfig {
    pub enabled: bool,
    /// Serve MCP over stdio.
    pub stdio: bool,
    /// Serve MCP over HTTP on `port`.
    pub http: bool,
    pub port: u16,
}
391
/// `adapters.grpc` settings (built-in default port 19792).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrpcAdapterConfig {
    pub enabled: bool,
    pub port: u16,
}
397
398impl BrainConfig {
399 #[allow(clippy::result_large_err)]
406 pub fn load() -> Result<Self, figment::Error> {
407 Self::load_from(None)
408 }
409
410 #[allow(clippy::result_large_err)]
412 pub fn load_from(config_path: Option<&Path>) -> Result<Self, figment::Error> {
413 let mut figment = Figment::new().merge(Yaml::string(DEFAULT_CONFIG));
415
416 let user_config = Self::user_config_path();
418 if user_config.exists() {
419 figment = figment.merge(Yaml::file(&user_config));
420 }
421
422 if let Some(path) = config_path {
424 figment = figment.merge(Yaml::file(path));
425 }
426
427 figment = figment.merge(Env::prefixed("BRAIN_").split("__"));
429
430 figment.extract()
431 }
432
433 pub fn data_dir(&self) -> PathBuf {
435 expand_tilde(&self.brain.data_dir)
436 }
437
438 pub fn ensure_data_dirs(&self) -> std::io::Result<()> {
440 let data_dir = self.data_dir();
441 let dirs = [
442 data_dir.clone(),
443 data_dir.join("db"), data_dir.join("ruvector"), data_dir.join("models"), data_dir.join("logs"), data_dir.join("exports"), ];
449
450 for dir in &dirs {
451 std::fs::create_dir_all(dir)?;
452 }
453
454 Ok(())
455 }
456
457 pub fn sqlite_path(&self) -> PathBuf {
459 self.data_dir().join("db").join("brain.db")
460 }
461
462 pub fn ruvector_path(&self) -> PathBuf {
464 self.data_dir().join("ruvector")
465 }
466
467 pub fn models_path(&self) -> PathBuf {
469 self.data_dir().join("models")
470 }
471
472 pub fn is_initialized() -> bool {
474 expand_tilde("~/.brain").exists()
475 }
476
477 pub fn write_default_config(force: bool) -> std::io::Result<Option<(PathBuf, String)>> {
486 let config_path = Self::user_config_path();
487
488 if config_path.exists() && !force {
489 return Ok(None);
490 }
491
492 if let Some(parent) = config_path.parent() {
494 std::fs::create_dir_all(parent)?;
495 }
496
497 let api_key = Self::generate_api_key();
499 let config = DEFAULT_CONFIG.replace(
500 "api_keys: []",
501 &format!(
502 "api_keys:\n - key: \"{}\"\n name: \"Default Key\"\n permissions: [read, write]",
503 api_key
504 ),
505 );
506
507 std::fs::write(&config_path, config)?;
508 Ok(Some((config_path, api_key)))
509 }
510
511 fn generate_api_key() -> String {
513 let mut buf = [0u8; 16];
514 getrandom::getrandom(&mut buf).expect("failed to obtain random bytes from OS");
515 let hex: String = buf.iter().map(|b| format!("{:02x}", b)).collect();
516 format!("brk_{}", hex)
517 }
518
519 pub fn user_config_path() -> PathBuf {
521 expand_tilde("~/.brain/config.yaml")
522 }
523
524 pub fn default_config_content() -> &'static str {
526 DEFAULT_CONFIG
527 }
528
529 pub fn validate(&self) -> Result<Vec<String>, String> {
535 let mut warnings: Vec<String> = Vec::new();
536
537 let mut ports: std::collections::HashMap<u16, &str> = std::collections::HashMap::new();
539 let adapter_ports = [
540 (self.adapters.http.port, "http"),
541 (self.adapters.ws.port, "ws"),
542 (self.adapters.mcp.port, "mcp"),
543 (self.adapters.grpc.port, "grpc"),
544 ];
545 for (port, name) in &adapter_ports {
546 if let Some(existing) = ports.insert(*port, name) {
547 return Err(format!(
548 "Port conflict: adapters '{}' and '{}' both use port {}",
549 existing, name, port
550 ));
551 }
552 }
553
554 let url = &self.llm.base_url;
556 if !url.starts_with("http://") && !url.starts_with("https://") {
557 return Err(format!(
558 "Invalid LLM base_url '{}': must start with http:// or https://",
559 url
560 ));
561 }
562
563 let data_dir = self.data_dir();
565 if data_dir.exists() {
566 let probe = data_dir.join(".brain_write_probe");
568 if std::fs::write(&probe, b"").is_err() {
569 return Err(format!(
570 "Data directory '{}' is not writable",
571 data_dir.display()
572 ));
573 }
574 let _ = std::fs::remove_file(&probe);
575 }
576
577 if self.access.api_keys.is_empty() {
579 warnings.push("No API keys configured — all adapters will reject authenticated requests. Run `brain init` or add a key under 'access.api_keys'.".to_string());
580 }
581
582 if self.llm.temperature > 1.5 {
583 warnings.push(format!(
584 "LLM temperature {:.1} is very high — responses may be unpredictable.",
585 self.llm.temperature
586 ));
587 }
588
589 if self.memory.consolidation.enabled && self.memory.consolidation.interval_hours == 0 {
590 warnings.push("Consolidation interval_hours is 0 — consolidation will run immediately on every daemon wake-up, which may impact performance.".to_string());
591 }
592
593 if self.actions.web_search.enabled {
594 match self.actions.web_search.provider {
595 WebSearchProvider::Custom if self.actions.web_search.endpoint.trim().is_empty() => {
596 warnings.push("Actions web_search provider is 'custom' but endpoint is empty; dispatches will fail with backend-not-configured.".to_string());
597 }
598 WebSearchProvider::Tavily if self.actions.web_search.api_key.trim().is_empty() => {
599 warnings.push("Actions web_search provider is 'tavily' but api_key is empty; dispatches will fail.".to_string());
600 }
601 _ => {}
602 }
603 }
604
605 if self.actions.messaging.enabled {
606 if self.actions.messaging.channels.is_empty() {
607 warnings.push("Actions messaging is enabled but actions.messaging.channels has no mappings; dispatches will fail for all channels.".to_string());
608 } else {
609 for (name, channel_cfg) in &self.actions.messaging.channels {
610 if channel_cfg.url.trim().is_empty() {
611 warnings.push(format!(
612 "actions.messaging.channels.{name}: url is empty; dispatches to this channel will fail."
613 ));
614 }
615 }
616 }
617 }
618
619 #[allow(clippy::float_cmp)]
621 if self.memory.search.hybrid_weight != 0.7 {
622 warnings.push(
623 "memory.search.hybrid_weight is set but unused — recall uses Reciprocal Rank Fusion (rrf_k) instead. This field will be removed in a future release.".to_string()
624 );
625 }
626 if self.memory.episodic.max_entries != 100_000 {
627 warnings.push(
628 "memory.episodic.max_entries is set but not enforced — no pruning logic exists yet. This field is reserved for future use.".to_string()
629 );
630 }
631 if self.memory.episodic.retention_days != 365 {
632 warnings.push(
633 "memory.episodic.retention_days is set but not enforced — recall uses a forgetting curve (decay_rate) instead of TTL-based retention.".to_string()
634 );
635 }
636
637 for (name, ms) in [
639 ("web_search.timeout_ms", self.actions.web_search.timeout_ms),
640 ("messaging.timeout_ms", self.actions.messaging.timeout_ms),
641 ] {
642 if ms == 0 {
643 warnings.push(format!(
644 "actions.{name} is 0; will be clamped to 1ms at runtime."
645 ));
646 } else if ms > 30_000 {
647 warnings.push(format!(
648 "actions.{name} is {}ms (>30s) — requests may block for a long time.",
649 ms
650 ));
651 }
652 }
653
654 let res = &self.actions.resilience;
656 if res.max_retries > 10 {
657 warnings.push(format!("actions.resilience.max_retries is {} (>10) — excessive retries may amplify failures.", res.max_retries));
658 }
659 if res.circuit_breaker_threshold == 0 {
660 warnings.push("actions.resilience.circuit_breaker_threshold is 0; circuit breaker will never trip.".to_string());
661 }
662
663 Ok(warnings)
664 }
665}
666
667impl Default for BrainConfig {
668 fn default() -> Self {
669 Self {
670 brain: GeneralConfig {
671 version: env!("CARGO_PKG_VERSION").to_string(),
672 data_dir: "~/.brain".to_string(),
673 },
674 storage: StorageConfig {
675 ruvector_path: "~/.brain/ruvector/".to_string(),
676 sqlite_path: "~/.brain/db/brain.db".to_string(),
677 hnsw: HnswConfig {
678 ef_construction: 200,
679 m: 16,
680 ef_search: 50,
681 },
682 },
683 llm: LlmConfig {
684 provider: "ollama".to_string(),
685 model: "qwen2.5-coder:7b".to_string(),
686 base_url: "http://localhost:11434".to_string(),
687 temperature: 0.7,
688 max_tokens: 4096,
689 api_key: String::new(),
690 },
691 embedding: EmbeddingConfig {
692 model: "nomic-embed-text".to_string(),
693 dimensions: 768,
694 },
695 memory: MemoryConfig {
696 episodic: EpisodicConfig {
697 max_entries: 100_000,
698 retention_days: 365,
699 },
700 semantic: SemanticConfig {
701 similarity_threshold: 0.65,
702 max_results: 20,
703 },
704 search: SearchConfig {
705 hybrid_weight: 0.7,
706 rrf_k: 60,
707 pre_fusion_limit: 50,
708 importance_weight: 0.3,
709 recency_weight: 0.2,
710 decay_rate: 0.01,
711 },
712 consolidation: ConsolidationConfig {
713 enabled: true,
714 interval_hours: 24,
715 forgetting_threshold: 0.05,
716 },
717 },
718 encryption: EncryptionConfig { enabled: false }, security: SecurityConfig {
720 exec_allowlist: vec![
721 "ls".into(),
722 "grep".into(),
723 "find".into(),
724 "git".into(),
725 "cargo".into(),
726 "rustc".into(),
727 ],
728 exec_timeout_seconds: 30,
729 },
730 actions: ActionsConfig {
731 web_search: WebSearchActionConfig {
732 enabled: true,
733 provider: WebSearchProvider::Searxng,
734 endpoint: "http://localhost:8888".to_string(),
735 api_key: String::new(),
736 timeout_ms: 3_000,
737 default_top_k: 5,
738 },
739 scheduling: SchedulingActionConfig {
740 enabled: false,
741 mode: SchedulingMode::PersistOnly,
742 },
743 messaging: MessagingActionConfig {
744 enabled: false,
745 timeout_ms: 3_000,
746 channels: HashMap::new(),
747 },
748 resilience: ResilienceConfig::default(),
749 },
750 proactivity: ProactivityConfig {
751 enabled: false,
752 max_per_day: 5,
753 min_interval_minutes: 60,
754 quiet_hours: QuietHoursConfig {
755 start: "22:00".to_string(),
756 end: "08:00".to_string(),
757 timezone: "UTC".to_string(),
758 },
759 delivery: DeliveryConfig::default(),
760 open_loop: OpenLoopDetectionConfig::default(),
761 },
762 adapters: AdaptersConfig {
763 http: HttpAdapterConfig {
764 enabled: true,
765 host: "127.0.0.1".to_string(),
766 port: 19789,
767 cors: true,
768 },
769 ws: WebSocketAdapterConfig {
770 enabled: true,
771 port: 19790,
772 },
773 mcp: McpAdapterConfig {
774 enabled: true,
775 stdio: true,
776 http: true,
777 port: 19791,
778 },
779 grpc: GrpcAdapterConfig {
780 enabled: true,
781 port: 19792,
782 },
783 },
784 access: AccessConfig {
785 api_keys: vec![ApiKeyConfig {
786 key: Self::generate_api_key(),
787 name: "Default Key".to_string(),
788 permissions: vec!["read".to_string(), "write".to_string()],
789 }],
790 },
791 }
792 }
793}
794
795fn expand_tilde(path: &str) -> PathBuf {
797 if let Some(rest) = path.strip_prefix("~/") {
798 if let Some(home) = dirs_home() {
799 return home.join(rest);
800 }
801 }
802 PathBuf::from(path)
803}
804
/// Returns the current user's home directory.
///
/// Reads `HOME`, falling back to `USERPROFILE` (the Windows equivalent). Empty
/// values are treated as unset, since joining onto an empty path would silently
/// produce a path relative to the working directory.
fn dirs_home() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
809
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = BrainConfig::default();
        assert_eq!(config.brain.data_dir, "~/.brain");
        assert_eq!(config.llm.provider, "ollama");
        assert_eq!(config.embedding.dimensions, 768);
        assert!(!config.encryption.enabled);
        assert_eq!(
            config.actions.web_search.provider,
            WebSearchProvider::Searxng
        );
        assert_eq!(config.actions.scheduling.mode, SchedulingMode::PersistOnly);
        assert!(!config.proactivity.enabled);
        assert!(config.adapters.http.enabled);
    }

    // Assumes a home directory is available (HOME set) in the test environment.
    #[test]
    fn test_expand_tilde() {
        let expanded = expand_tilde("~/.brain");
        assert!(!expanded.to_str().unwrap().starts_with('~'));
        assert!(expanded.to_str().unwrap().ends_with(".brain"));
    }

    #[test]
    fn test_data_dir_paths() {
        let config = BrainConfig::default();
        let data = config.data_dir();
        assert!(data.to_str().unwrap().ends_with(".brain"));
        assert!(config.sqlite_path().to_str().unwrap().ends_with("brain.db"));
        assert!(config
            .ruvector_path()
            .to_str()
            .unwrap()
            .ends_with("ruvector"));
    }

    #[test]
    fn test_load_from_defaults() {
        use figment::providers::Serialized;
        let figment = Figment::new().merge(Serialized::defaults(BrainConfig::default()));
        let config: BrainConfig = figment.extract().unwrap();
        assert_eq!(config.llm.model, "qwen2.5-coder:7b");
        assert_eq!(config.memory.search.rrf_k, 60);
        assert_eq!(config.memory.search.pre_fusion_limit, 50);
        assert!((config.memory.search.importance_weight - 0.3).abs() < f64::EPSILON);
        assert!((config.memory.search.recency_weight - 0.2).abs() < f64::EPSILON);
        assert!((config.memory.search.decay_rate - 0.01).abs() < f64::EPSILON);
    }

    /// Data directory for validation tests. `validate` only probes writability
    /// when the directory already exists, so a missing one is fine.
    fn writable_test_data_dir() -> String {
        std::env::temp_dir()
            .join("brain-core-tests")
            .to_string_lossy()
            .to_string()
    }

    /// Default config pointed at a temp data dir, with no API keys configured.
    fn validated_config() -> BrainConfig {
        let mut c = BrainConfig::default();
        c.brain.data_dir = writable_test_data_dir();
        c.access.api_keys.clear();
        c
    }

    #[test]
    fn test_validate_generated_key_no_warning() {
        let mut config = BrainConfig::default();
        config.brain.data_dir = writable_test_data_dir();
        let warnings = config.validate().expect("default config should be valid");
        assert!(
            !warnings.iter().any(|w| w.contains("No API keys")),
            "should not have empty-keys warning with a generated key, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_no_api_keys_warning() {
        let config = validated_config();
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("No API keys")),
            "expected no-api-keys warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_port_conflict_is_hard_error() {
        let mut config = validated_config();
        config.adapters.ws.port = config.adapters.http.port;
        let err = config
            .validate()
            .expect_err("should fail with port conflict");
        assert!(
            err.contains("Port conflict"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn test_validate_bad_llm_url_is_hard_error() {
        let mut config = validated_config();
        config.llm.base_url = "ftp://invalid.example.com".to_string();
        let err = config.validate().expect_err("should fail with bad URL");
        assert!(
            err.contains("Invalid LLM base_url"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_validate_high_temperature_warning() {
        let mut config = validated_config();
        config.llm.temperature = 2.0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("temperature")),
            "expected temperature warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_consolidation_interval_zero_warning() {
        let mut config = validated_config();
        config.memory.consolidation.enabled = true;
        config.memory.consolidation.interval_hours = 0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings.iter().any(|w| w.contains("interval_hours")),
            "expected interval warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_actions_defaults_deserialize() {
        // Extract from the embedded defaults only. `BrainConfig::load()` would also merge
        // ~/.brain/config.yaml and BRAIN_* env vars, making this test machine-dependent.
        let config: BrainConfig = Figment::new()
            .merge(Yaml::string(DEFAULT_CONFIG))
            .extract()
            .expect("embedded defaults should load");
        assert!(config.actions.web_search.enabled);
        assert_eq!(
            config.actions.web_search.provider,
            WebSearchProvider::Searxng
        );
        assert_eq!(config.actions.web_search.default_top_k, 5);
        assert_eq!(config.actions.scheduling.mode, SchedulingMode::PersistOnly);
        assert!(!config.actions.messaging.enabled);
    }

    #[test]
    fn test_validate_actions_warning_custom_without_endpoint() {
        let mut config = validated_config();
        config.actions.web_search.enabled = true;
        config.actions.web_search.provider = WebSearchProvider::Custom;
        config.actions.web_search.endpoint.clear();
        config.actions.messaging.enabled = true;
        config.actions.messaging.channels.clear();
        let warnings = config.validate().expect("config should still be valid");
        assert!(warnings.iter().any(|w| w.contains("'custom'")));
        assert!(warnings.iter().any(|w| w.contains("messaging")));
    }

    #[test]
    fn test_validate_tavily_without_api_key_warning() {
        let mut config = validated_config();
        config.actions.web_search.enabled = true;
        config.actions.web_search.provider = WebSearchProvider::Tavily;
        config.actions.web_search.api_key.clear();
        let warnings = config.validate().expect("config should still be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("'tavily'") && w.contains("api_key")),
            "expected tavily api_key warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_searxng_no_web_search_warning() {
        let mut config = validated_config();
        config.actions.web_search.enabled = true;
        config.actions.web_search.provider = WebSearchProvider::Searxng;
        let warnings = config.validate().expect("config should still be valid");
        assert!(
            !warnings.iter().any(|w| w.contains("web_search")),
            "SearXNG with default endpoint should not trigger web_search warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_http_and_https_urls_accepted() {
        let mut config = validated_config();
        config.llm.base_url = "https://api.example.com/v1".to_string();
        assert!(config.validate().is_ok());

        config.llm.base_url = "http://localhost:11434".to_string();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_all_unique_ports_ok() {
        let config = validated_config();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_timeout_zero_warning() {
        let mut config = validated_config();
        config.actions.web_search.timeout_ms = 0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("web_search.timeout_ms is 0")),
            "expected timeout_ms=0 warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_timeout_too_high_warning() {
        let mut config = validated_config();
        config.actions.messaging.timeout_ms = 60_000;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("timeout_ms") && w.contains("60000")),
            "expected high timeout warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_resilience_max_retries_warning() {
        let mut config = validated_config();
        config.actions.resilience.max_retries = 15;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("max_retries") && w.contains("15")),
            "expected max_retries warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_validate_resilience_threshold_zero_warning() {
        let mut config = validated_config();
        config.actions.resilience.circuit_breaker_threshold = 0;
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("circuit_breaker_threshold")),
            "expected circuit_breaker_threshold=0 warning, got: {:?}",
            warnings
        );
    }

    #[test]
    fn test_resilience_defaults() {
        let res = ResilienceConfig::default();
        assert_eq!(res.max_retries, 2);
        assert_eq!(res.retry_base_ms, 500);
        assert_eq!(res.circuit_breaker_threshold, 5);
        assert_eq!(res.circuit_breaker_cooldown_secs, 60);
    }

    #[test]
    fn test_channel_config_old_format_compat() {
        let yaml = r#"
enabled: false
timeout_ms: 3000
channels:
  alerts: "https://example.com/hook"
  ops: "https://slack.example.com/webhook"
"#;
        let cfg: MessagingActionConfig =
            serde_yaml::from_str(yaml).expect("old format should deserialize");
        assert_eq!(cfg.channels.len(), 2);
        assert_eq!(cfg.channels["alerts"].url, "https://example.com/hook");
        assert!(cfg.channels["alerts"].body.is_empty());
        assert!(cfg.channels["alerts"].headers.is_empty());
    }

    #[test]
    fn test_channel_config_new_format() {
        let yaml = r#"
enabled: true
timeout_ms: 3000
channels:
  alerts:
    url: "https://hooks.slack.com/services/T/B/x"
    body: '{"text": "{{content}}"}'
    headers:
      Authorization: "Bearer tok123"
"#;
        let cfg: MessagingActionConfig =
            serde_yaml::from_str(yaml).expect("new format should deserialize");
        assert_eq!(cfg.channels.len(), 1);
        let ch = &cfg.channels["alerts"];
        assert_eq!(ch.url, "https://hooks.slack.com/services/T/B/x");
        assert_eq!(ch.body, r#"{"text": "{{content}}"}"#);
        assert_eq!(ch.headers["Authorization"], "Bearer tok123");
    }

    #[test]
    fn test_channel_config_mixed_format() {
        let yaml = r#"
enabled: true
timeout_ms: 3000
channels:
  simple: "https://example.com/hook"
  custom:
    url: "https://discord.com/api/webhooks/123/abc"
    body: '{"content": "{{content}}"}'
"#;
        let cfg: MessagingActionConfig =
            serde_yaml::from_str(yaml).expect("mixed format should deserialize");
        assert_eq!(cfg.channels.len(), 2);
        assert_eq!(cfg.channels["simple"].url, "https://example.com/hook");
        assert!(cfg.channels["simple"].body.is_empty());
        let custom = &cfg.channels["custom"];
        assert_eq!(custom.url, "https://discord.com/api/webhooks/123/abc");
        assert!(!custom.body.is_empty());
        assert!(custom.headers.is_empty());
    }

    #[test]
    fn test_validate_channel_empty_url_warning() {
        let mut config = validated_config();
        config.actions.messaging.enabled = true;
        config.actions.messaging.channels.insert(
            "bad".into(),
            ChannelConfig {
                url: "".into(),
                body: String::new(),
                headers: HashMap::new(),
            },
        );
        let warnings = config.validate().expect("should be valid");
        assert!(
            warnings
                .iter()
                .any(|w| w.contains("channels.bad") && w.contains("url is empty")),
            "expected empty-url warning, got: {:?}",
            warnings
        );
    }
}