Skip to main content

zeph_config/
loader.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::path::Path;
5
6use crate::error::ConfigError;
7use crate::root::Config;
8
9impl Config {
10    /// Load configuration from a TOML file with env var overrides.
11    ///
12    /// Falls back to sensible defaults when the file does not exist.
13    ///
14    /// # Errors
15    ///
16    /// Returns an error if the file exists but cannot be read or parsed.
17    pub fn load(path: &Path) -> Result<Self, ConfigError> {
18        let mut config = if path.exists() {
19            let content = std::fs::read_to_string(path)?;
20            toml::from_str::<Self>(&content)?
21        } else {
22            Self::default()
23        };
24
25        config.apply_env_overrides();
26        config.normalize_legacy_runtime_defaults();
27        Ok(config)
28    }
29
30    /// Validate configuration values are within sane bounds.
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if any value is out of range.
35    #[allow(clippy::too_many_lines)]
36    pub fn validate(&self) -> Result<(), ConfigError> {
37        if self.memory.history_limit > 10_000 {
38            return Err(ConfigError::Validation(format!(
39                "history_limit must be <= 10000, got {}",
40                self.memory.history_limit
41            )));
42        }
43        if self.memory.context_budget_tokens > 1_000_000 {
44            return Err(ConfigError::Validation(format!(
45                "context_budget_tokens must be <= 1000000, got {}",
46                self.memory.context_budget_tokens
47            )));
48        }
49        if self.agent.max_tool_iterations > 100 {
50            return Err(ConfigError::Validation(format!(
51                "max_tool_iterations must be <= 100, got {}",
52                self.agent.max_tool_iterations
53            )));
54        }
55        if self.a2a.rate_limit == 0 {
56            return Err(ConfigError::Validation("a2a.rate_limit must be > 0".into()));
57        }
58        if self.gateway.rate_limit == 0 {
59            return Err(ConfigError::Validation(
60                "gateway.rate_limit must be > 0".into(),
61            ));
62        }
63        if self.gateway.max_body_size > 10_485_760 {
64            return Err(ConfigError::Validation(format!(
65                "gateway.max_body_size must be <= 10485760 (10 MiB), got {}",
66                self.gateway.max_body_size
67            )));
68        }
69        if self.memory.token_safety_margin <= 0.0 {
70            return Err(ConfigError::Validation(format!(
71                "token_safety_margin must be > 0.0, got {}",
72                self.memory.token_safety_margin
73            )));
74        }
75        if self.memory.tool_call_cutoff == 0 {
76            return Err(ConfigError::Validation(
77                "tool_call_cutoff must be >= 1".into(),
78            ));
79        }
80        if let crate::memory::CompressionStrategy::Proactive {
81            threshold_tokens,
82            max_summary_tokens,
83        } = &self.memory.compression.strategy
84        {
85            if *threshold_tokens < 1_000 {
86                return Err(ConfigError::Validation(format!(
87                    "compression.threshold_tokens must be >= 1000, got {threshold_tokens}"
88                )));
89            }
90            if *max_summary_tokens < 128 {
91                return Err(ConfigError::Validation(format!(
92                    "compression.max_summary_tokens must be >= 128, got {max_summary_tokens}"
93                )));
94            }
95        }
96        if !self.memory.soft_compaction_threshold.is_finite()
97            || self.memory.soft_compaction_threshold <= 0.0
98            || self.memory.soft_compaction_threshold >= 1.0
99        {
100            return Err(ConfigError::Validation(format!(
101                "soft_compaction_threshold must be in (0.0, 1.0) exclusive, got {}",
102                self.memory.soft_compaction_threshold
103            )));
104        }
105        if !self.memory.hard_compaction_threshold.is_finite()
106            || self.memory.hard_compaction_threshold <= 0.0
107            || self.memory.hard_compaction_threshold >= 1.0
108        {
109            return Err(ConfigError::Validation(format!(
110                "hard_compaction_threshold must be in (0.0, 1.0) exclusive, got {}",
111                self.memory.hard_compaction_threshold
112            )));
113        }
114        if self.memory.soft_compaction_threshold >= self.memory.hard_compaction_threshold {
115            return Err(ConfigError::Validation(format!(
116                "soft_compaction_threshold ({}) must be less than hard_compaction_threshold ({})",
117                self.memory.soft_compaction_threshold, self.memory.hard_compaction_threshold,
118            )));
119        }
120        if self.memory.graph.temporal_decay_rate < 0.0
121            || self.memory.graph.temporal_decay_rate > 10.0
122        {
123            return Err(ConfigError::Validation(format!(
124                "memory.graph.temporal_decay_rate must be in [0.0, 10.0], got {}",
125                self.memory.graph.temporal_decay_rate
126            )));
127        }
128        if self.memory.compression.probe.enabled {
129            let probe = &self.memory.compression.probe;
130            if !probe.threshold.is_finite() || probe.threshold <= 0.0 || probe.threshold > 1.0 {
131                return Err(ConfigError::Validation(format!(
132                    "memory.compression.probe.threshold must be in (0.0, 1.0], got {}",
133                    probe.threshold
134                )));
135            }
136            if !probe.hard_fail_threshold.is_finite()
137                || probe.hard_fail_threshold < 0.0
138                || probe.hard_fail_threshold >= 1.0
139            {
140                return Err(ConfigError::Validation(format!(
141                    "memory.compression.probe.hard_fail_threshold must be in [0.0, 1.0), got {}",
142                    probe.hard_fail_threshold
143                )));
144            }
145            if probe.hard_fail_threshold >= probe.threshold {
146                return Err(ConfigError::Validation(format!(
147                    "memory.compression.probe.hard_fail_threshold ({}) must be less than \
148                     memory.compression.probe.threshold ({})",
149                    probe.hard_fail_threshold, probe.threshold
150                )));
151            }
152            if probe.max_questions < 1 {
153                return Err(ConfigError::Validation(
154                    "memory.compression.probe.max_questions must be >= 1".into(),
155                ));
156            }
157            if probe.timeout_secs < 1 {
158                return Err(ConfigError::Validation(
159                    "memory.compression.probe.timeout_secs must be >= 1".into(),
160                ));
161            }
162        }
163        // MCP server validation
164        {
165            use std::collections::HashSet;
166            let mut seen_oauth_vault_keys: HashSet<String> = HashSet::new();
167            for s in &self.mcp.servers {
168                // headers and oauth are mutually exclusive
169                if !s.headers.is_empty() && s.oauth.as_ref().is_some_and(|o| o.enabled) {
170                    return Err(ConfigError::Validation(format!(
171                        "MCP server '{}': cannot use both 'headers' and 'oauth' simultaneously",
172                        s.id
173                    )));
174                }
175                // vault key collision detection
176                if s.oauth.as_ref().is_some_and(|o| o.enabled) {
177                    let key = format!("ZEPH_MCP_OAUTH_{}", s.id.to_uppercase().replace('-', "_"));
178                    if !seen_oauth_vault_keys.insert(key.clone()) {
179                        return Err(ConfigError::Validation(format!(
180                            "MCP server '{}' has vault key collision ('{key}'): another server \
181                             with the same normalized ID already uses this key",
182                            s.id
183                        )));
184                    }
185                }
186            }
187        }
188
189        self.experiments
190            .validate()
191            .map_err(ConfigError::Validation)?;
192
193        if self.orchestration.plan_cache.enabled {
194            self.orchestration
195                .plan_cache
196                .validate()
197                .map_err(ConfigError::Validation)?;
198        }
199
200        let ct = self.orchestration.completeness_threshold;
201        if !ct.is_finite() || !(0.0..=1.0).contains(&ct) {
202            return Err(ConfigError::Validation(format!(
203                "orchestration.completeness_threshold must be in [0.0, 1.0], got {ct}"
204            )));
205        }
206
207        // Focus config validation
208        if self.agent.focus.compression_interval == 0 {
209            return Err(ConfigError::Validation(
210                "agent.focus.compression_interval must be >= 1".into(),
211            ));
212        }
213        if self.agent.focus.min_messages_per_focus == 0 {
214            return Err(ConfigError::Validation(
215                "agent.focus.min_messages_per_focus must be >= 1".into(),
216            ));
217        }
218
219        // SideQuest config validation
220        if self.memory.sidequest.interval_turns == 0 {
221            return Err(ConfigError::Validation(
222                "memory.sidequest.interval_turns must be >= 1".into(),
223            ));
224        }
225        if !self.memory.sidequest.max_eviction_ratio.is_finite()
226            || self.memory.sidequest.max_eviction_ratio <= 0.0
227            || self.memory.sidequest.max_eviction_ratio > 1.0
228        {
229            return Err(ConfigError::Validation(format!(
230                "memory.sidequest.max_eviction_ratio must be in (0.0, 1.0], got {}",
231                self.memory.sidequest.max_eviction_ratio
232            )));
233        }
234
235        let sct = self.llm.semantic_cache_threshold;
236        if !(sct.is_finite() && (0.0..=1.0).contains(&sct)) {
237            return Err(ConfigError::Validation(format!(
238                "llm.semantic_cache_threshold must be in [0.0, 1.0], got {sct} \
239                 (override via ZEPH_LLM_SEMANTIC_CACHE_THRESHOLD env var)"
240            )));
241        }
242
243        self.validate_provider_names()?;
244
245        Ok(())
246    }
247
248    #[allow(clippy::too_many_lines)]
249    fn validate_provider_names(&self) -> Result<(), ConfigError> {
250        use std::collections::HashSet;
251        let known: HashSet<String> = self
252            .llm
253            .providers
254            .iter()
255            .map(super::providers::ProviderEntry::effective_name)
256            .collect();
257
258        let fields: &[(&str, &crate::providers::ProviderName)] = &[
259            (
260                "memory.tiers.scene_provider",
261                &self.memory.tiers.scene_provider,
262            ),
263            (
264                "memory.compression.compress_provider",
265                &self.memory.compression.compress_provider,
266            ),
267            (
268                "memory.consolidation.consolidation_provider",
269                &self.memory.consolidation.consolidation_provider,
270            ),
271            (
272                "memory.admission.admission_provider",
273                &self.memory.admission.admission_provider,
274            ),
275            (
276                "memory.admission.goal_utility_provider",
277                &self.memory.admission.goal_utility_provider,
278            ),
279            (
280                "memory.store_routing.routing_classifier_provider",
281                &self.memory.store_routing.routing_classifier_provider,
282            ),
283            (
284                "skills.learning.feedback_provider",
285                &self.skills.learning.feedback_provider,
286            ),
287            (
288                "skills.learning.arise_trace_provider",
289                &self.skills.learning.arise_trace_provider,
290            ),
291            (
292                "skills.learning.stem_provider",
293                &self.skills.learning.stem_provider,
294            ),
295            (
296                "skills.learning.erl_extract_provider",
297                &self.skills.learning.erl_extract_provider,
298            ),
299            (
300                "mcp.pruning.pruning_provider",
301                &self.mcp.pruning.pruning_provider,
302            ),
303            (
304                "mcp.tool_discovery.embedding_provider",
305                &self.mcp.tool_discovery.embedding_provider,
306            ),
307            (
308                "security.response_verification.verifier_provider",
309                &self.security.response_verification.verifier_provider,
310            ),
311            (
312                "orchestration.planner_provider",
313                &self.orchestration.planner_provider,
314            ),
315            (
316                "orchestration.verify_provider",
317                &self.orchestration.verify_provider,
318            ),
319            (
320                "orchestration.tool_provider",
321                &self.orchestration.tool_provider,
322            ),
323        ];
324
325        for (field, name) in fields {
326            if !name.is_empty() && !known.contains(name.as_str()) {
327                return Err(ConfigError::Validation(format!(
328                    "{field} = {:?} does not match any [[llm.providers]] entry",
329                    name.as_str()
330                )));
331            }
332        }
333
334        if let Some(triage) = self
335            .llm
336            .complexity_routing
337            .as_ref()
338            .and_then(|cr| cr.triage_provider.as_ref())
339            .filter(|t| !t.is_empty() && !known.contains(t.as_str()))
340        {
341            return Err(ConfigError::Validation(format!(
342                "llm.complexity_routing.triage_provider = {:?} does not match any \
343                 [[llm.providers]] entry",
344                triage.as_str()
345            )));
346        }
347
348        if let Some(embed) = self
349            .llm
350            .router
351            .as_ref()
352            .and_then(|r| r.bandit.as_ref())
353            .map(|b| &b.embedding_provider)
354            .filter(|p| !p.is_empty() && !known.contains(p.as_str()))
355        {
356            return Err(ConfigError::Validation(format!(
357                "llm.router.bandit.embedding_provider = {:?} does not match any \
358                 [[llm.providers]] entry",
359                embed.as_str()
360            )));
361        }
362
363        Ok(())
364    }
365
366    fn normalize_legacy_runtime_defaults(&mut self) {
367        use crate::defaults::{
368            default_debug_dir, default_log_file_path, default_skills_dir, default_sqlite_path,
369            is_legacy_default_debug_dir, is_legacy_default_log_file, is_legacy_default_skills_path,
370            is_legacy_default_sqlite_path,
371        };
372
373        if is_legacy_default_sqlite_path(&self.memory.sqlite_path) {
374            self.memory.sqlite_path = default_sqlite_path();
375        }
376
377        for skill_path in &mut self.skills.paths {
378            if is_legacy_default_skills_path(skill_path) {
379                *skill_path = default_skills_dir();
380            }
381        }
382
383        if is_legacy_default_debug_dir(&self.debug.output_dir) {
384            self.debug.output_dir = default_debug_dir();
385        }
386
387        if is_legacy_default_log_file(&self.logging.file) {
388            self.logging.file = default_log_file_path();
389        }
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    fn config_with_sct(threshold: f32) -> Config {
        let mut cfg = Config::default();
        cfg.llm.semantic_cache_threshold = threshold;
        cfg
    }

    #[test]
    fn semantic_cache_threshold_valid_zero() {
        assert!(config_with_sct(0.0).validate().is_ok());
    }

    #[test]
    fn semantic_cache_threshold_valid_mid() {
        assert!(config_with_sct(0.5).validate().is_ok());
    }

    #[test]
    fn semantic_cache_threshold_valid_one() {
        assert!(config_with_sct(1.0).validate().is_ok());
    }

    #[test]
    fn semantic_cache_threshold_invalid_negative() {
        let err = config_with_sct(-0.1).validate().unwrap_err();
        assert!(
            err.to_string().contains("semantic_cache_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn semantic_cache_threshold_invalid_above_one() {
        let err = config_with_sct(1.1).validate().unwrap_err();
        assert!(
            err.to_string().contains("semantic_cache_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn semantic_cache_threshold_invalid_nan() {
        let err = config_with_sct(f32::NAN).validate().unwrap_err();
        assert!(
            err.to_string().contains("semantic_cache_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn semantic_cache_threshold_invalid_infinity() {
        let err = config_with_sct(f32::INFINITY).validate().unwrap_err();
        assert!(
            err.to_string().contains("semantic_cache_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn semantic_cache_threshold_invalid_neg_infinity() {
        let err = config_with_sct(f32::NEG_INFINITY).validate().unwrap_err();
        assert!(
            err.to_string().contains("semantic_cache_threshold"),
            "unexpected error: {err}"
        );
    }

    // The float fields below are assigned via `str::parse` so these tests do not
    // depend on whether the field is declared as `f32` or `f64`.

    #[test]
    fn token_safety_margin_nan_invalid() {
        let mut cfg = Config::default();
        cfg.memory.token_safety_margin = "NaN".parse().expect("NaN parses as a float");
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("token_safety_margin"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn token_safety_margin_infinity_invalid() {
        let mut cfg = Config::default();
        cfg.memory.token_safety_margin = "inf".parse().expect("inf parses as a float");
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("token_safety_margin"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn temporal_decay_rate_nan_invalid() {
        let mut cfg = Config::default();
        cfg.memory.graph.temporal_decay_rate = "NaN".parse().expect("NaN parses as a float");
        let err = cfg.validate().unwrap_err();
        assert!(
            err.to_string().contains("temporal_decay_rate"),
            "unexpected error: {err}"
        );
    }

    fn probe_config(enabled: bool, threshold: f32, hard_fail_threshold: f32) -> Config {
        let mut cfg = Config::default();
        cfg.memory.compression.probe.enabled = enabled;
        cfg.memory.compression.probe.threshold = threshold;
        cfg.memory.compression.probe.hard_fail_threshold = hard_fail_threshold;
        cfg
    }

    #[test]
    fn probe_disabled_skips_validation() {
        // Invalid thresholds when probe is disabled must not cause errors.
        let cfg = probe_config(false, 0.0, 1.0);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn probe_valid_thresholds() {
        let cfg = probe_config(true, 0.6, 0.35);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn probe_threshold_zero_invalid() {
        let err = probe_config(true, 0.0, 0.0).validate().unwrap_err();
        assert!(
            err.to_string().contains("probe.threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn probe_hard_fail_threshold_one_invalid() {
        // The upper bound is exclusive, so exactly 1.0 must be rejected.
        let err = probe_config(true, 0.6, 1.0).validate().unwrap_err();
        assert!(
            err.to_string().contains("probe.hard_fail_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn probe_hard_fail_gte_threshold_invalid() {
        let err = probe_config(true, 0.3, 0.9).validate().unwrap_err();
        assert!(
            err.to_string().contains("probe.hard_fail_threshold"),
            "unexpected error: {err}"
        );
    }

    fn config_with_completeness_threshold(ct: f32) -> Config {
        let mut cfg = Config::default();
        cfg.orchestration.completeness_threshold = ct;
        cfg
    }

    #[test]
    fn completeness_threshold_valid_zero() {
        assert!(config_with_completeness_threshold(0.0).validate().is_ok());
    }

    #[test]
    fn completeness_threshold_valid_default() {
        assert!(config_with_completeness_threshold(0.7).validate().is_ok());
    }

    #[test]
    fn completeness_threshold_valid_one() {
        assert!(config_with_completeness_threshold(1.0).validate().is_ok());
    }

    #[test]
    fn completeness_threshold_invalid_negative() {
        let err = config_with_completeness_threshold(-0.1)
            .validate()
            .unwrap_err();
        assert!(
            err.to_string().contains("completeness_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn completeness_threshold_invalid_above_one() {
        let err = config_with_completeness_threshold(1.1)
            .validate()
            .unwrap_err();
        assert!(
            err.to_string().contains("completeness_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn completeness_threshold_invalid_nan() {
        let err = config_with_completeness_threshold(f32::NAN)
            .validate()
            .unwrap_err();
        assert!(
            err.to_string().contains("completeness_threshold"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn completeness_threshold_invalid_infinity() {
        let err = config_with_completeness_threshold(f32::INFINITY)
            .validate()
            .unwrap_err();
        assert!(
            err.to_string().contains("completeness_threshold"),
            "unexpected error: {err}"
        );
    }

    fn config_with_provider(name: &str) -> Config {
        let mut cfg = Config::default();
        cfg.llm.providers.push(crate::providers::ProviderEntry {
            provider_type: crate::providers::ProviderKind::Ollama,
            name: Some(name.into()),
            ..Default::default()
        });
        cfg
    }

    #[test]
    fn validate_provider_names_all_empty_ok() {
        let cfg = Config::default();
        assert!(cfg.validate_provider_names().is_ok());
    }

    #[test]
    fn validate_provider_names_matching_provider_ok() {
        let mut cfg = config_with_provider("fast");
        cfg.memory.admission.admission_provider = crate::providers::ProviderName::new("fast");
        assert!(cfg.validate_provider_names().is_ok());
    }

    #[test]
    fn validate_provider_names_unknown_provider_err() {
        let mut cfg = config_with_provider("fast");
        cfg.memory.admission.admission_provider =
            crate::providers::ProviderName::new("nonexistent");
        let err = cfg.validate_provider_names().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("admission_provider") && msg.contains("nonexistent"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn validate_provider_names_triage_provider_none_ok() {
        let mut cfg = config_with_provider("fast");
        cfg.llm.complexity_routing = Some(crate::providers::ComplexityRoutingConfig {
            triage_provider: None,
            ..Default::default()
        });
        assert!(cfg.validate_provider_names().is_ok());
    }

    #[test]
    fn validate_provider_names_triage_provider_matching_ok() {
        let mut cfg = config_with_provider("fast");
        cfg.llm.complexity_routing = Some(crate::providers::ComplexityRoutingConfig {
            triage_provider: Some(crate::providers::ProviderName::new("fast")),
            ..Default::default()
        });
        assert!(cfg.validate_provider_names().is_ok());
    }

    #[test]
    fn validate_provider_names_triage_provider_unknown_err() {
        let mut cfg = config_with_provider("fast");
        cfg.llm.complexity_routing = Some(crate::providers::ComplexityRoutingConfig {
            triage_provider: Some(crate::providers::ProviderName::new("ghost")),
            ..Default::default()
        });
        let err = cfg.validate_provider_names().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("triage_provider") && msg.contains("ghost"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn load_missing_file_falls_back_to_defaults() {
        let path = std::env::temp_dir().join(format!(
            "zeph-config-loader-missing-{}.toml",
            std::process::id()
        ));
        assert!(
            !path.exists(),
            "test precondition: {} must not exist",
            path.display()
        );
        assert!(Config::load(&path).is_ok());
    }

    #[test]
    fn load_directory_path_is_error() {
        // The path exists but cannot be read as a file; this must surface as an
        // error instead of silently falling back to defaults.
        assert!(Config::load(&std::env::temp_dir()).is_err());
    }

    // Regression test for issue #2599: TOML float values must deserialise without error
    // in config sub-structs that contain f32/f64 fields.
    #[test]
    fn toml_float_fields_deserialise_correctly() {
        let router: crate::providers::RouterConfig = toml::from_str(
            r"[reputation]
enabled = true
decay_factor = 0.95
weight = 0.3
",
        )
        .expect("RouterConfig with float fields must deserialise");
        let reputation = router
            .reputation
            .expect("[reputation] section must be present");
        assert!((reputation.decay_factor - 0.95).abs() < f64::EPSILON);

        let bandit: crate::providers::BanditConfig =
            toml::from_str("cost_weight = 0.3\nalpha = 1.0\n")
                .expect("BanditConfig with float fields must deserialise");
        assert!((bandit.cost_weight - 0.3_f32).abs() < f32::EPSILON);

        let semantic: crate::memory::SemanticConfig = toml::from_str("mmr_lambda = 0.7\n")
            .expect("SemanticConfig with float fields must deserialise");
        assert!((semantic.mmr_lambda - 0.7_f32).abs() < f32::EPSILON);

        let skills: crate::features::SkillsConfig =
            toml::from_str("disambiguation_threshold = 0.25\n")
                .expect("SkillsConfig with float fields must deserialise");
        assert!((skills.disambiguation_threshold - 0.25_f32).abs() < f32::EPSILON);
    }
}