Skip to main content

brainos_core/config/
loader.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4};
5
6use figment::{
7    providers::{Env, Format, Yaml},
8    Figment,
9};
10
11use super::*;
12
13impl BrainConfig {
14    /// Load configuration from all sources.
15    ///
16    /// Priority (highest wins):
17    /// 1. Environment variables (`BRAIN_LLM__MODEL=...`)
18    /// 2. User config (`~/.brain/config.yaml`)
19    /// 3. Embedded defaults (compiled into binary)
20    #[allow(clippy::result_large_err)]
21    pub fn load() -> Result<Self, figment::Error> {
22        Self::load_from(None)
23    }
24
25    /// Load configuration with an optional explicit config path.
26    #[allow(clippy::result_large_err)]
27    pub fn load_from(config_path: Option<&Path>) -> Result<Self, figment::Error> {
28        // Layer 1: Embedded defaults (always available, no file needed)
29        let mut figment = Figment::new().merge(Yaml::string(super::DEFAULT_CONFIG));
30
31        // Layer 2: User config (~/.brain/config.yaml)
32        let user_config = Self::user_config_path();
33        if user_config.exists() {
34            figment = figment.merge(Yaml::file(&user_config));
35        }
36
37        // Layer 3: Explicit config path (if provided)
38        if let Some(path) = config_path {
39            figment = figment.merge(Yaml::file(path));
40        }
41
42        // Layer 4: Environment variables (BRAIN_LLM__MODEL=...)
43        figment = figment.merge(Env::prefixed("BRAIN_").split("__"));
44
45        let mut cfg: Self = figment.extract()?;
46
47        // Post-load: if the user set legacy llm.{base_url,model,api_key} via
48        // env vars but `providers[]` is non-empty, the multi-provider path
49        // would silently ignore the env override. Forward those overrides
50        // onto providers[0] so `BRAIN_LLM__BASE_URL=...` does what users
51        // expect regardless of how the YAML is structured.
52        if !cfg.llm.providers.is_empty() {
53            #[allow(deprecated)]
54            {
55                if std::env::var("BRAIN_LLM__BASE_URL").is_ok() {
56                    cfg.llm.providers[0].base_url = cfg.llm.base_url.clone();
57                }
58                if std::env::var("BRAIN_LLM__MODEL").is_ok() {
59                    cfg.llm.providers[0].model = cfg.llm.model.clone();
60                }
61                if std::env::var("BRAIN_LLM__API_KEY").is_ok() {
62                    cfg.llm.providers[0].api_key = cfg.llm.api_key.clone();
63                }
64            }
65        }
66
67        Ok(cfg)
68    }
69
70    /// Resolve the data directory path, expanding `~` to the home directory.
71    pub fn data_dir(&self) -> PathBuf {
72        expand_tilde(&self.brain.data_dir)
73    }
74
75    /// Ensure the data directory and subdirectories exist.
76    pub fn ensure_data_dirs(&self) -> std::io::Result<()> {
77        let data_dir = self.data_dir();
78        let dirs = [
79            data_dir.clone(),
80            data_dir.join("db"),
81            data_dir.join("ruvector"),
82            data_dir.join("models"),
83            data_dir.join("logs"),
84            data_dir.join("exports"),
85        ];
86
87        for dir in &dirs {
88            std::fs::create_dir_all(dir)?;
89        }
90
91        Ok(())
92    }
93
94    /// Path to the SQLite database file.
95    pub fn sqlite_path(&self) -> PathBuf {
96        self.data_dir().join("db").join("brain.db")
97    }
98
99    /// Path to the RuVector directory.
100    pub fn ruvector_path(&self) -> PathBuf {
101        self.data_dir().join("ruvector")
102    }
103
104    /// Path to the models directory.
105    pub fn models_path(&self) -> PathBuf {
106        self.data_dir().join("models")
107    }
108
109    /// Check whether Brain has been initialized (data dir exists).
110    pub fn is_initialized() -> bool {
111        expand_tilde("~/.brain").exists()
112    }
113
114    /// Write the default config to `~/.brain/config.yaml`.
115    ///
116    /// Returns `(config_path, generated_api_key)`, or `None` if the file already
117    /// exists and `force` is false.
118    pub fn write_default_config(force: bool) -> std::io::Result<Option<(PathBuf, String)>> {
119        let config_path = Self::user_config_path();
120
121        if config_path.exists() && !force {
122            return Ok(None);
123        }
124
125        if let Some(parent) = config_path.parent() {
126            std::fs::create_dir_all(parent)?;
127        }
128
129        let api_key = Self::generate_api_key();
130        let config = super::DEFAULT_CONFIG.replace(
131            "api_keys: []",
132            &format!(
133                "api_keys:\n    - key: \"{}\"\n      name: \"Default Key\"\n      permissions: [read, write]",
134                api_key
135            ),
136        );
137
138        std::fs::write(&config_path, config)?;
139        Ok(Some((config_path, api_key)))
140    }
141
142    /// Generate a random 36-char API key with `brk_` prefix.
143    fn generate_api_key() -> String {
144        let mut buf = [0u8; 16];
145        getrandom::getrandom(&mut buf).expect("failed to obtain random bytes from OS");
146        let hex: String = buf.iter().map(|b| format!("{:02x}", b)).collect();
147        format!("brk_{}", hex)
148    }
149
150    /// Path to user config file.
151    ///
152    /// `BRAIN_CONFIG` env var overrides the default `~/.brain/config.yaml`,
153    /// useful for sandboxes, CI, and multi-config workflows.
154    pub fn user_config_path() -> PathBuf {
155        if let Ok(p) = std::env::var("BRAIN_CONFIG") {
156            if !p.trim().is_empty() {
157                return PathBuf::from(p);
158            }
159        }
160        expand_tilde("~/.brain/config.yaml")
161    }
162
163    /// Get the embedded default config content.
164    pub fn default_config_content() -> &'static str {
165        super::DEFAULT_CONFIG
166    }
167
168    /// Validate configuration and return a list of warnings.
169    pub fn validate(&self) -> Result<Vec<String>, String> {
170        let mut warnings: Vec<String> = Vec::new();
171
172        let mut ports: HashMap<u16, &str> = HashMap::new();
173        let adapter_ports = [
174            (self.adapters.http.port, "http"),
175            (self.adapters.ws.port, "ws"),
176            (self.adapters.mcp.port, "mcp"),
177            (self.adapters.grpc.port, "grpc"),
178        ];
179        for (port, name) in &adapter_ports {
180            if let Some(existing) = ports.insert(*port, name) {
181                return Err(format!(
182                    "Port conflict: adapters '{}' and '{}' both use port {}",
183                    existing, name, port
184                ));
185            }
186        }
187
188        #[allow(deprecated)]
189        let url = &self.llm.base_url;
190        if !url.starts_with("http://") && !url.starts_with("https://") {
191            return Err(format!(
192                "Invalid LLM base_url '{}': must start with http:// or https://",
193                url
194            ));
195        }
196
197        let data_dir = self.data_dir();
198        if data_dir.exists() {
199            let probe = data_dir.join(".brain_write_probe");
200            if std::fs::write(&probe, b"").is_err() {
201                return Err(format!(
202                    "Data directory '{}' is not writable",
203                    data_dir.display()
204                ));
205            }
206            let _ = std::fs::remove_file(&probe);
207        }
208
209        if self.access.api_keys.is_empty() {
210            return Err("No API keys configured. Run `brain init` to generate a config with a secure API key, or configure 'access.api_keys' manually.".to_string());
211        }
212
213        if self.llm.temperature > 1.5 {
214            warnings.push(format!(
215                "LLM temperature {:.1} is very high — responses may be unpredictable.",
216                self.llm.temperature
217            ));
218        }
219
220        // Issue 40: `llm.provider` is #[deprecated]. When `providers[]`
221        // is non-empty AND the legacy field is also set, the loader
222        // backfills providers[0] from the legacy shape — but operators
223        // typically don't realise the legacy field is no longer
224        // authoritative for LLM routing. Surface the overlap loudly so
225        // they migrate to `providers[]`.
226        #[allow(deprecated)]
227        let legacy_provider_set = !self.llm.provider.trim().is_empty();
228        if !self.llm.providers.is_empty() && legacy_provider_set {
229            warnings.push(
230                "Legacy `llm.provider` is set alongside `llm.providers[]`. \
231                 `llm.providers[]` is the authoritative routing surface; \
232                 the legacy field is only kept for embedder selection \
233                 and will be retired in a future release."
234                    .to_string(),
235            );
236        }
237
238        if self.memory.consolidation.enabled && self.memory.consolidation.interval_hours == 0 {
239            warnings.push("Consolidation interval_hours is 0 — consolidation will run immediately on every daemon wake-up, which may impact performance.".to_string());
240        }
241
242        if self.actions.web_search.enabled {
243            match self.actions.web_search.provider {
244                WebSearchProvider::Custom if self.actions.web_search.endpoint.trim().is_empty() => {
245                    warnings.push("Actions web_search provider is 'custom' but endpoint is empty; dispatches will fail with backend-not-configured.".to_string());
246                }
247                WebSearchProvider::Tavily if self.actions.web_search.api_key.trim().is_empty() => {
248                    warnings.push("Actions web_search provider is 'tavily' but api_key is empty; dispatches will fail.".to_string());
249                }
250                _ => {}
251            }
252        }
253
254        if self.actions.messaging.enabled {
255            if self.actions.messaging.channels.is_empty() {
256                if self.channel.transports.is_empty() && self.channel.relays.is_empty() {
257                    warnings.push("Actions messaging is enabled but neither actions.messaging.channels, channel.transports, nor channel.relays are configured; dispatches will fail.".to_string());
258                }
259            } else {
260                for (name, channel_cfg) in &self.actions.messaging.channels {
261                    if channel_cfg.url.trim().is_empty() {
262                        warnings.push(format!(
263                            "actions.messaging.channels.{name}: url is empty; dispatches to this channel will fail."
264                        ));
265                    }
266                }
267            }
268        }
269
270        for (name, ms) in [
271            ("web_search.timeout_ms", self.actions.web_search.timeout_ms),
272            ("messaging.timeout_ms", self.actions.messaging.timeout_ms),
273        ] {
274            if ms == 0 {
275                warnings.push(format!(
276                    "actions.{name} is 0; will be clamped to 1ms at runtime."
277                ));
278            } else if ms > 30_000 {
279                warnings.push(format!(
280                    "actions.{name} is {}ms (>30s) — requests may block for a long time.",
281                    ms
282                ));
283            }
284        }
285
286        let mut service_names: HashMap<&str, ()> = HashMap::new();
287        for svc in &self.monitoring.services {
288            if svc.name.trim().is_empty() {
289                warnings.push(
290                    "monitoring.services has an entry with an empty name; its alerts will be hard to attribute.".to_string(),
291                );
292            } else if service_names.insert(svc.name.as_str(), ()).is_some() {
293                warnings.push(format!(
294                    "monitoring.services has a duplicate name '{}'; both probe loops still run but their alerts are indistinguishable.",
295                    svc.name
296                ));
297            }
298            if svc.target.trim().is_empty() {
299                warnings.push(format!(
300                    "monitoring.services.{}: target is empty; this probe will always report the service as down.",
301                    svc.name
302                ));
303            } else if matches!(svc.kind, ServiceCheckKind::Http)
304                && !svc.target.starts_with("http://")
305                && !svc.target.starts_with("https://")
306            {
307                warnings.push(format!(
308                    "monitoring.services.{}: kind is 'http' but target '{}' is not an http(s) URL; probes will fail.",
309                    svc.name, svc.target
310                ));
311            } else if matches!(svc.kind, ServiceCheckKind::Tcp) && !svc.target.contains(':') {
312                warnings.push(format!(
313                    "monitoring.services.{}: kind is 'tcp' but target '{}' is not 'host:port'; probes will fail.",
314                    svc.name, svc.target
315                ));
316            }
317            if svc.interval_secs == 0 {
318                warnings.push(format!(
319                    "monitoring.services.{}: interval_secs is 0 — it will be clamped to 1s at runtime, probing in a tight loop.",
320                    svc.name
321                ));
322            }
323        }
324
325        let res = &self.actions.resilience;
326        if res.max_retries > 10 {
327            warnings.push(format!("actions.resilience.max_retries is {} (>10) — excessive retries may amplify failures.", res.max_retries));
328        }
329        if res.circuit_breaker_threshold == 0 {
330            warnings.push("actions.resilience.circuit_breaker_threshold is 0; circuit breaker will never trip.".to_string());
331        }
332
333        Ok(warnings)
334    }
335}
336
337impl Default for BrainConfig {
338    // `llm.provider` is `#[deprecated]` (Issue 40) but still load-bearing as
339    // the implicit single entry when `providers[]` is empty — populate it.
340    #[allow(deprecated)]
341    fn default() -> Self {
342        Self {
343            brain: GeneralConfig {
344                version: env!("CARGO_PKG_VERSION").to_string(),
345                data_dir: "~/.brain".to_string(),
346            },
347            storage: StorageConfig {
348                ruvector_path: "~/.brain/ruvector/".to_string(),
349                sqlite_path: "~/.brain/db/brain.db".to_string(),
350                hnsw: HnswConfig {
351                    ef_construction: 200,
352                    max_elements: HnswConfig::default_max_elements(),
353                    m: 16,
354                    ef_search: 50,
355                },
356            },
357            llm: LlmConfig {
358                provider: "ollama".to_string(),
359                model: "qwen2.5-coder:7b".to_string(),
360                base_url: "http://localhost:11434".to_string(),
361                temperature: 0.7,
362                max_tokens: 4096,
363                context_window: super::default_context_window(),
364                api_key: String::new(),
365                api_key_file: None,
366                providers: Vec::new(),
367            },
368            embedding: EmbeddingConfig {
369                model: "nomic-embed-text".to_string(),
370                dimensions: 768,
371            },
372            memory: MemoryConfig {
373                semantic: SemanticConfig {
374                    similarity_threshold: 0.65,
375                    max_results: 20,
376                },
377                search: SearchConfig {
378                    rrf_k: 60,
379                    pre_fusion_limit: 50,
380                    importance_weight: 0.3,
381                    recency_weight: 0.2,
382                    decay_rate: 0.01,
383                },
384                consolidation: ConsolidationConfig {
385                    enabled: true,
386                    interval_hours: 24,
387                    forgetting_threshold: 0.05,
388                },
389            },
390            encryption: EncryptionConfig { enabled: false },
391            security: SecurityConfig {
392                exec_allowlist: vec![
393                    // Read-only inspection
394                    "ls".into(),
395                    "cat".into(),
396                    "head".into(),
397                    "tail".into(),
398                    "wc".into(),
399                    "file".into(),
400                    "stat".into(),
401                    // Text processing
402                    "grep".into(),
403                    "find".into(),
404                    "sort".into(),
405                    "uniq".into(),
406                    "cut".into(),
407                    "awk".into(),
408                    "sed".into(),
409                    // Shell discovery / pathing
410                    "which".into(),
411                    "command".into(),
412                    "type".into(),
413                    "test".into(),
414                    "basename".into(),
415                    "dirname".into(),
416                    "realpath".into(),
417                    // Output
418                    "echo".into(),
419                    "printf".into(),
420                    "true".into(),
421                    "false".into(),
422                    // Toolchain
423                    "git".into(),
424                    "cargo".into(),
425                    "rustc".into(),
426                    "rustup".into(),
427                    // Shell wrapper for the shell-mode execution tier
428                    // (see SandboxCommand::shell). The per-binary
429                    // allowlist is bypassed for commands wrapped by
430                    // `sh -c` — rlimits + Seatbelt + timeout +
431                    // forbidden_commands still apply.
432                    "sh".into(),
433                ],
434                exec_timeout_seconds: 30,
435                // Empty → handler defaults to `$HOME`. Configure
436                // `security.allowed_paths: ["~/code", "~/work"]` to
437                // restrict further.
438                allowed_paths: Vec::new(),
439            },
440            actions: ActionsConfig {
441                web_search: WebSearchActionConfig {
442                    // On by default. DuckDuckGo HTML scraping is the
443                    // zero-config built-in so first-run has working web
444                    // search without Docker or an API key.
445                    enabled: true,
446                    provider: WebSearchProvider::DuckDuckGo,
447                    endpoint: "http://localhost:8888".to_string(),
448                    api_key: String::new(),
449                    timeout_ms: 3_000,
450                    default_top_k: 5,
451                },
452                scheduling: SchedulingActionConfig {
453                    enabled: false,
454                    mode: SchedulingMode::PersistOnly,
455                },
456                messaging: MessagingActionConfig {
457                    enabled: false,
458                    timeout_ms: 3_000,
459                    channels: HashMap::new(),
460                },
461                resilience: ResilienceConfig::default(),
462            },
463            proactivity: ProactivityConfig {
464                // Synced to `default.yaml`: programmatic `BrainConfig::default()`
465                // must match the embedded YAML so the loader and the struct
466                // produce the same shape (Issue 36).
467                enabled: true,
468                max_per_day: 2,
469                min_interval_minutes: 60,
470                quiet_hours: QuietHoursConfig {
471                    start: "20:00".to_string(),
472                    end: "10:00".to_string(),
473                    timezone: "UTC".to_string(),
474                },
475                delivery: DeliveryConfig::default(),
476                open_loop: OpenLoopDetectionConfig::default(),
477            },
478            adapters: AdaptersConfig {
479                http: HttpAdapterConfig {
480                    enabled: true,
481                    host: "127.0.0.1".to_string(),
482                    port: 19789,
483                    cors: true,
484                    sse_redact_previews: false,
485                },
486                ws: WebSocketAdapterConfig {
487                    enabled: true,
488                    port: 19790,
489                },
490                mcp: McpAdapterConfig {
491                    enabled: true,
492                    port: 19791,
493                },
494                grpc: GrpcAdapterConfig {
495                    enabled: true,
496                    port: 19792,
497                },
498                terminal: TerminalAdapterConfig::default_enabled(),
499            },
500            access: AccessConfig {
501                api_keys: vec![ApiKeyConfig {
502                    key: Self::generate_api_key(),
503                    name: "Default Key".to_string(),
504                    permissions: vec!["read".to_string(), "write".to_string()],
505                    agent_id: None,
506                }],
507                rate_limit: ClientRateLimitConfig::default(),
508            },
509            channel: ChannelIntelligenceConfig::default(),
510            agents: AgentsConfig::default(),
511            confirm: ConfirmConfig::default(),
512            identity: identity::IdentityConfig::default(),
513            reflex: ReflexConfig::default(),
514            logging: crate::config::LoggingConfig::default(),
515            learning: crate::config::LearningConfig::default(),
516            observability: crate::config::ObservabilityConfig::default(),
517            monitoring: crate::config::MonitoringConfig::default(),
518        }
519    }
520}
521
/// Expand a leading `~` in `path` to the current user's home directory.
///
/// - `"~"` on its own resolves to the home directory itself.
/// - `"~/rest"` resolves to `<home>/rest`.
/// - Any other input (absolute paths, relative paths, `~user/...`) is
///   returned unchanged.
///
/// When the home directory cannot be determined (see `dirs_home`), the input
/// is returned unchanged as well.
pub(crate) fn expand_tilde(path: &str) -> PathBuf {
    if path == "~" {
        // A bare tilde names the home directory; without this branch it would
        // be kept literally and later created as a directory called `~`.
        if let Some(home) = dirs_home() {
            return home;
        }
    } else if let Some(rest) = path.strip_prefix("~/") {
        if let Some(home) = dirs_home() {
            return home.join(rest);
        }
    }
    PathBuf::from(path)
}

/// Home directory taken from the `HOME` environment variable, or `None` when
/// the variable is not set.
fn dirs_home() -> Option<PathBuf> {
    std::env::var_os("HOME").map(PathBuf::from)
}