Skip to main content

zeph_config/
providers.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};
8
9/// Newtype wrapper for a provider name referencing an entry in `[[llm.providers]]`.
10///
11/// Using a dedicated type instead of bare `String` makes provider cross-references
12/// explicit in the type system and enables validation at config load time.
13#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(transparent)]
15pub struct ProviderName(String);
16
17impl ProviderName {
18    /// Create a new `ProviderName` from any string-like value.
19    ///
20    /// An empty string is a sentinel meaning "use the primary provider" and is the
21    /// default value. Check [`is_empty`](Self::is_empty) before using in routing.
22    ///
23    /// # Examples
24    ///
25    /// ```
26    /// use zeph_config::providers::ProviderName;
27    ///
28    /// let name = ProviderName::new("fast");
29    /// assert_eq!(name.as_str(), "fast");
30    /// ```
31    #[must_use]
32    pub fn new(name: impl Into<String>) -> Self {
33        Self(name.into())
34    }
35
36    /// Return `true` when this is the empty sentinel (use primary provider).
37    ///
38    /// # Examples
39    ///
40    /// ```
41    /// use zeph_config::providers::ProviderName;
42    ///
43    /// assert!(ProviderName::default().is_empty());
44    /// assert!(!ProviderName::new("fast").is_empty());
45    /// ```
46    #[must_use]
47    pub fn is_empty(&self) -> bool {
48        self.0.is_empty()
49    }
50
51    /// Return the inner string slice.
52    ///
53    /// # Examples
54    ///
55    /// ```
56    /// use zeph_config::providers::ProviderName;
57    ///
58    /// let name = ProviderName::new("quality");
59    /// assert_eq!(name.as_str(), "quality");
60    /// ```
61    #[must_use]
62    pub fn as_str(&self) -> &str {
63        &self.0
64    }
65}
66
67impl fmt::Display for ProviderName {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        self.0.fmt(f)
70    }
71}
72
73impl AsRef<str> for ProviderName {
74    fn as_ref(&self) -> &str {
75        &self.0
76    }
77}
78
79impl std::ops::Deref for ProviderName {
80    type Target = str;
81
82    fn deref(&self) -> &str {
83        &self.0
84    }
85}
86
87impl PartialEq<str> for ProviderName {
88    fn eq(&self, other: &str) -> bool {
89        self.0 == other
90    }
91}
92
93impl PartialEq<&str> for ProviderName {
94    fn eq(&self, other: &&str) -> bool {
95        self.0 == *other
96    }
97}
98
/// Default time-to-live for cached responses: one hour, in seconds.
fn default_response_cache_ttl_secs() -> u64 {
    60 * 60
}

/// Default cosine-similarity threshold for a semantic cache hit.
fn default_semantic_cache_threshold() -> f32 {
    0.95_f32
}

/// Default number of cached entries examined per semantic lookup.
fn default_semantic_cache_max_candidates() -> u32 {
    10_u32
}

/// Default smoothing factor for the router's EMA latency estimator.
fn default_router_ema_alpha() -> f64 {
    0.1_f64
}

/// Default number of turns between provider re-rankings.
fn default_router_reorder_interval() -> u64 {
    10_u64
}
118
/// Default embedding model name.
fn default_embedding_model() -> String {
    String::from("qwen3-embedding")
}

/// Default Candle model source.
fn default_candle_source() -> String {
    String::from("huggingface")
}

/// Default Candle chat template name.
fn default_chat_template() -> String {
    String::from("chatml")
}

/// Default Candle compute device.
fn default_candle_device() -> String {
    String::from("cpu")
}
134
/// Default sampling temperature for Candle generation.
fn default_temperature() -> f64 {
    0.7_f64
}

/// Default `max_tokens` generation parameter.
fn default_max_tokens() -> usize {
    2_048
}

/// Default generation seed.
fn default_seed() -> u64 {
    42_u64
}

/// Default repeat-penalty generation parameter.
fn default_repeat_penalty() -> f32 {
    1.1_f32
}

/// Default `repeat_last_n` generation parameter.
fn default_repeat_last_n() -> usize {
    64_usize
}
154
/// Default minimum quality score for cascade routing to accept a response.
fn default_cascade_quality_threshold() -> f64 {
    0.5_f64
}

/// Default per-request budget of quality-based cascade escalations.
fn default_cascade_max_escalations() -> u8 {
    2_u8
}

/// Default size of the per-provider cascade quality history window.
fn default_cascade_window_size() -> usize {
    50_usize
}
166
/// Default per-session reputation decay factor.
fn default_reputation_decay_factor() -> f64 {
    0.95_f64
}

/// Default weight of reputation in the routing score blend.
fn default_reputation_weight() -> f64 {
    0.3_f64
}

/// Default number of observations before reputation influences routing.
fn default_reputation_min_observations() -> u64 {
    5_u64
}
178
/// Default STT provider name: the empty string, meaning "auto-detect".
#[must_use]
pub fn default_stt_provider() -> String {
    String::default()
}

/// Default STT transcription language hint: `"auto"`.
#[must_use]
pub fn default_stt_language() -> String {
    "auto".to_owned()
}
190
191/// Returns the default embedding model name used by `[llm] embedding_model`.
192#[must_use]
193pub fn get_default_embedding_model() -> String {
194    default_embedding_model()
195}
196
197/// Returns the default response cache TTL in seconds.
198#[must_use]
199pub fn get_default_response_cache_ttl_secs() -> u64 {
200    default_response_cache_ttl_secs()
201}
202
203/// Returns the default EMA alpha for the router latency estimator.
204#[must_use]
205pub fn get_default_router_ema_alpha() -> f64 {
206    default_router_ema_alpha()
207}
208
209/// Returns the default router reorder interval (turns between provider re-ranking).
210#[must_use]
211pub fn get_default_router_reorder_interval() -> u64 {
212    default_router_reorder_interval()
213}
214
215/// LLM provider backend selector.
216///
217/// Used in `[[llm.providers]]` entries as the `type` field.
218///
219/// # Example (TOML)
220///
221/// ```toml
222/// [[llm.providers]]
223/// type = "openai"
224/// model = "gpt-4o"
225/// name = "quality"
226/// ```
227#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
228#[serde(rename_all = "lowercase")]
229pub enum ProviderKind {
230    /// Local Ollama server (default base URL: `http://localhost:11434`).
231    Ollama,
232    /// Anthropic Claude API.
233    Claude,
234    /// `OpenAI` API.
235    OpenAi,
236    /// Google Gemini API.
237    Gemini,
238    /// Local Candle inference (CPU/GPU, no external server required).
239    Candle,
240    /// OpenAI-compatible third-party API (e.g. Groq, Together AI, LM Studio).
241    Compatible,
242}
243
244impl ProviderKind {
245    /// Return the lowercase string identifier for this provider kind.
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// use zeph_config::ProviderKind;
251    ///
252    /// assert_eq!(ProviderKind::Claude.as_str(), "claude");
253    /// assert_eq!(ProviderKind::OpenAi.as_str(), "openai");
254    /// ```
255    #[must_use]
256    pub fn as_str(self) -> &'static str {
257        match self {
258            Self::Ollama => "ollama",
259            Self::Claude => "claude",
260            Self::OpenAi => "openai",
261            Self::Gemini => "gemini",
262            Self::Candle => "candle",
263            Self::Compatible => "compatible",
264        }
265    }
266}
267
268impl std::fmt::Display for ProviderKind {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        f.write_str(self.as_str())
271    }
272}
273
274/// LLM configuration, nested under `[llm]` in TOML.
275///
276/// Declares the provider pool and controls routing, embedding, caching, and STT.
277/// All providers are declared in `[[llm.providers]]`; subsystems reference them by
278/// the `name` field using a `*_provider` config key.
279///
280/// # Example (TOML)
281///
282/// ```toml
283/// [[llm.providers]]
284/// name = "fast"
285/// type = "openai"
286/// model = "gpt-4o-mini"
287///
288/// [[llm.providers]]
289/// name = "quality"
290/// type = "claude"
291/// model = "claude-opus-4-5"
292///
293/// [llm]
294/// routing = "none"
295/// embedding_model = "qwen3-embedding"
296/// ```
297#[derive(Debug, Deserialize, Serialize)]
298pub struct LlmConfig {
299    /// Provider pool. First entry is default unless one is marked `default = true`.
300    #[serde(default, skip_serializing_if = "Vec::is_empty")]
301    pub providers: Vec<ProviderEntry>,
302
303    /// Routing strategy for multi-provider configs.
304    #[serde(default, skip_serializing_if = "is_routing_none")]
305    pub routing: LlmRoutingStrategy,
306
307    /// Task-based routes (only used when `routing = "task"`).
308    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
309    pub routes: std::collections::HashMap<String, Vec<String>>,
310
311    #[serde(default = "default_embedding_model_opt")]
312    pub embedding_model: String,
313    #[serde(default, skip_serializing_if = "Option::is_none")]
314    pub candle: Option<CandleConfig>,
315    #[serde(default)]
316    pub stt: Option<SttConfig>,
317    #[serde(default)]
318    pub response_cache_enabled: bool,
319    #[serde(default = "default_response_cache_ttl_secs")]
320    pub response_cache_ttl_secs: u64,
321    /// Enable semantic similarity-based response caching. Requires embedding support.
322    #[serde(default)]
323    pub semantic_cache_enabled: bool,
324    /// Cosine similarity threshold for semantic cache hits (0.0–1.0).
325    ///
326    /// Only the highest-scoring candidate above this threshold is returned.
327    /// Lower values produce more cache hits but risk returning less relevant responses.
328    /// Recommended range: 0.92–0.98; default: 0.95.
329    #[serde(default = "default_semantic_cache_threshold")]
330    pub semantic_cache_threshold: f32,
331    /// Maximum cached entries to examine per semantic lookup (SQL `LIMIT` clause in
332    /// `ResponseCache::get_semantic()`). Controls the recall-vs-performance tradeoff:
333    ///
334    /// - **Higher values** (e.g. 50): scan more entries, better chance of finding a
335    ///   semantically similar cached response, but slower queries.
336    /// - **Lower values** (e.g. 5): faster queries, but may miss relevant cached entries
337    ///   when the cache is large.
338    /// - **Default (10)**: balanced middle ground for typical workloads.
339    ///
340    /// Tuning guidance: set to 50+ when recall matters more than latency (e.g. long-running
341    /// sessions with many cached responses); reduce to 5 for low-latency interactive use.
342    /// Env override: `ZEPH_LLM_SEMANTIC_CACHE_MAX_CANDIDATES`.
343    #[serde(default = "default_semantic_cache_max_candidates")]
344    pub semantic_cache_max_candidates: u32,
345    #[serde(default)]
346    pub router_ema_enabled: bool,
347    #[serde(default = "default_router_ema_alpha")]
348    pub router_ema_alpha: f64,
349    #[serde(default = "default_router_reorder_interval")]
350    pub router_reorder_interval: u64,
351    /// Routing configuration for Thompson/Cascade strategies.
352    #[serde(default, skip_serializing_if = "Option::is_none")]
353    pub router: Option<RouterConfig>,
354    /// Provider-specific instruction file to inject into the system prompt.
355    /// Merged with `agent.instruction_files` at startup.
356    #[serde(default, skip_serializing_if = "Option::is_none")]
357    pub instruction_file: Option<std::path::PathBuf>,
358    /// Shorthand model spec for tool-pair summarization and context compaction.
359    /// Format: `ollama/<model>`, `claude[/<model>]`, `openai[/<model>]`, `compatible/<name>`, `candle`.
360    /// Ignored when `[llm.summary_provider]` is set.
361    #[serde(default, skip_serializing_if = "Option::is_none")]
362    pub summary_model: Option<String>,
363    /// Structured provider config for summarization. Takes precedence over `summary_model`.
364    #[serde(default, skip_serializing_if = "Option::is_none")]
365    pub summary_provider: Option<ProviderEntry>,
366
367    /// Complexity triage routing configuration. Required when `routing = "triage"`.
368    #[serde(default, skip_serializing_if = "Option::is_none")]
369    pub complexity_routing: Option<ComplexityRoutingConfig>,
370}
371
372fn default_embedding_model_opt() -> String {
373    default_embedding_model()
374}
375
376#[allow(clippy::trivially_copy_pass_by_ref)]
377fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
378    *s == LlmRoutingStrategy::None
379}
380
381impl LlmConfig {
382    /// Effective provider kind for the primary (first/default) provider in the pool.
383    #[must_use]
384    pub fn effective_provider(&self) -> ProviderKind {
385        self.providers
386            .first()
387            .map_or(ProviderKind::Ollama, |e| e.provider_type)
388    }
389
390    /// Effective base URL for the primary provider.
391    #[must_use]
392    pub fn effective_base_url(&self) -> &str {
393        self.providers
394            .first()
395            .and_then(|e| e.base_url.as_deref())
396            .unwrap_or("http://localhost:11434")
397    }
398
399    /// Effective model for the primary provider.
400    #[must_use]
401    pub fn effective_model(&self) -> &str {
402        self.providers
403            .first()
404            .and_then(|e| e.model.as_deref())
405            .unwrap_or("qwen3:8b")
406    }
407
408    /// Find the provider entry designated for STT.
409    ///
410    /// Resolution priority:
411    /// 1. `[llm.stt].provider` matches `[[llm.providers]].name` and the entry has `stt_model`
412    /// 2. `[llm.stt].provider` is empty — fall through to auto-detect
413    /// 3. First provider with `stt_model` set (auto-detect fallback)
414    /// 4. `None` — STT disabled
415    #[must_use]
416    pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
417        let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
418        if name_hint.is_empty() {
419            self.providers.iter().find(|p| p.stt_model.is_some())
420        } else {
421            self.providers
422                .iter()
423                .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
424        }
425    }
426
427    /// Validate that the config uses the new `[[llm.providers]]` format.
428    ///
429    /// # Errors
430    ///
431    /// Returns `ConfigError::Validation` when no providers are configured.
432    pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
433        Ok(())
434    }
435
436    /// Validate STT config cross-references.
437    ///
438    /// # Errors
439    ///
440    /// Returns `ConfigError::Validation` when the referenced STT provider does not exist.
441    pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
442        use crate::error::ConfigError;
443
444        let Some(stt) = &self.stt else {
445            return Ok(());
446        };
447        if stt.provider.is_empty() {
448            return Ok(());
449        }
450        let found = self
451            .providers
452            .iter()
453            .find(|p| p.effective_name() == stt.provider);
454        match found {
455            None => {
456                return Err(ConfigError::Validation(format!(
457                    "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
458                    stt.provider
459                )));
460            }
461            Some(entry) if entry.stt_model.is_none() => {
462                tracing::warn!(
463                    provider = stt.provider,
464                    "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
465                );
466            }
467            _ => {}
468        }
469        Ok(())
470    }
471}
472
473/// Speech-to-text configuration, nested under `[llm.stt]` in TOML.
474///
475/// When set, Zeph uses the referenced provider for voice transcription.
476/// The provider must have an `stt_model` field set in its `[[llm.providers]]` entry.
477///
478/// # Example (TOML)
479///
480/// ```toml
481/// [llm.stt]
482/// provider = "fast"
483/// language = "en"
484/// ```
485#[derive(Debug, Clone, Deserialize, Serialize)]
486pub struct SttConfig {
487    /// Provider name from `[[llm.providers]]`. Empty string means auto-detect first provider
488    /// with `stt_model` set.
489    #[serde(default = "default_stt_provider")]
490    pub provider: String,
491    /// Language hint for transcription (e.g. `"en"`, `"auto"`).
492    #[serde(default = "default_stt_language")]
493    pub language: String,
494}
495
496/// Routing strategy selection for multi-provider routing.
497#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
498#[serde(rename_all = "lowercase")]
499pub enum RouterStrategyConfig {
500    /// Exponential moving average latency-aware ordering.
501    #[default]
502    Ema,
503    /// Thompson Sampling with Beta distributions (persistence-backed).
504    Thompson,
505    /// Cascade routing: try cheapest provider first, escalate on degenerate output.
506    Cascade,
507    /// PILOT: `LinUCB` contextual bandit with online learning and cost-aware reward.
508    Bandit,
509}
510
511/// Agent Stability Index (ASI) configuration.
512///
513/// Tracks per-provider response coherence via a sliding window of response embeddings.
514/// When coherence drops below `coherence_threshold`, the provider's routing prior is
515/// penalized by `penalty_weight`. Disabled by default; session-only (no persistence).
516///
517/// # Known Limitation
518///
519/// ASI embeddings are computed in a background `tokio::spawn` task after the response is
520/// returned to the caller. Under high request rates, the coherence score used for routing
521/// may lag 1–2 responses behind due to this fire-and-forget design. With the default
522/// `window = 5`, this lag is tolerable — coherence is a slow-moving signal.
523#[derive(Debug, Clone, Deserialize, Serialize)]
524pub struct AsiConfig {
525    /// Enable ASI coherence tracking. Default: false.
526    #[serde(default)]
527    pub enabled: bool,
528
529    /// Sliding window size for response embeddings per provider. Default: 5.
530    #[serde(default = "default_asi_window")]
531    pub window: usize,
532
533    /// Coherence score [0.0, 1.0] below which the provider is penalized. Default: 0.7.
534    #[serde(default = "default_asi_coherence_threshold")]
535    pub coherence_threshold: f32,
536
537    /// Penalty weight applied to Thompson beta / EMA score on low coherence. Default: 0.3.
538    ///
539    /// For Thompson, this shifts the beta prior: `beta += penalty_weight * (threshold - coherence)`.
540    /// For EMA, the score is multiplied by `max(0.5, coherence / threshold)`.
541    #[serde(default = "default_asi_penalty_weight")]
542    pub penalty_weight: f32,
543}
544
/// Default ASI sliding-window length (embeddings kept per provider).
fn default_asi_window() -> usize {
    5_usize
}

/// Default ASI coherence threshold.
fn default_asi_coherence_threshold() -> f32 {
    0.7_f32
}

/// Default ASI low-coherence penalty weight.
fn default_asi_penalty_weight() -> f32 {
    0.3_f32
}
556
557impl Default for AsiConfig {
558    fn default() -> Self {
559        Self {
560            enabled: false,
561            window: default_asi_window(),
562            coherence_threshold: default_asi_coherence_threshold(),
563            penalty_weight: default_asi_penalty_weight(),
564        }
565    }
566}
567
568/// Routing configuration for multi-provider setups.
569#[derive(Debug, Clone, Deserialize, Serialize)]
570pub struct RouterConfig {
571    /// Routing strategy: `"ema"` (default), `"thompson"`, `"cascade"`, or `"bandit"`.
572    #[serde(default)]
573    pub strategy: RouterStrategyConfig,
574    /// Path for persisting Thompson Sampling state. Defaults to `~/.zeph/router_thompson_state.json`.
575    ///
576    /// # Security
577    ///
578    /// This path is user-controlled. The application writes and reads a JSON file at
579    /// this location. Ensure the path is within a directory that is not world-writable
580    /// (e.g., avoid `/tmp`). The file is created with mode `0o600` on Unix.
581    #[serde(default)]
582    pub thompson_state_path: Option<String>,
583    /// Cascade routing configuration. Only used when `strategy = "cascade"`.
584    #[serde(default)]
585    pub cascade: Option<CascadeConfig>,
586    /// Bayesian reputation scoring configuration (RAPS). Disabled by default.
587    #[serde(default)]
588    pub reputation: Option<ReputationConfig>,
589    /// PILOT bandit routing configuration. Only used when `strategy = "bandit"`.
590    #[serde(default)]
591    pub bandit: Option<BanditConfig>,
592    /// Embedding-based quality gate threshold for Thompson/EMA routing. Default: disabled.
593    ///
594    /// When set, after provider selection, the cosine similarity between the query embedding
595    /// and the response embedding is computed. If below this threshold, the next provider in
596    /// the ordered list is tried. On exhaustion, the best response seen is returned.
597    ///
598    /// Only applies to Thompson and EMA strategies. Cascade uses its own quality classifier.
599    /// Fail-open: embedding errors disable the gate for that request.
600    #[serde(default)]
601    pub quality_gate: Option<f32>,
602    /// Agent Stability Index configuration. Disabled by default.
603    #[serde(default)]
604    pub asi: Option<AsiConfig>,
605    /// Maximum number of concurrent `embed_batch` calls through the router.
606    ///
607    /// Limits simultaneous embedding HTTP requests to prevent provider rate-limiting
608    /// and memory pressure during indexing or high-frequency recall. Default: 4.
609    /// Set to 0 to disable the semaphore (unlimited concurrency).
610    #[serde(default = "default_embed_concurrency")]
611    pub embed_concurrency: usize,
612}
613
/// Default cap on concurrent `embed_batch` calls through the router.
fn default_embed_concurrency() -> usize {
    4_usize
}
617
618/// Configuration for Bayesian reputation scoring (RAPS — Reputation-Adjusted Provider Selection).
619///
620/// When enabled, quality outcomes from tool execution shift the routing scores over time,
621/// giving an advantage to providers that consistently produce valid tool arguments.
622///
623/// Default: disabled. Set `enabled = true` to activate.
624#[derive(Debug, Clone, Deserialize, Serialize)]
625pub struct ReputationConfig {
626    /// Enable reputation scoring. Default: false.
627    #[serde(default)]
628    pub enabled: bool,
629    /// Session-level decay factor applied on each load. Range: (0.0, 1.0]. Default: 0.95.
630    /// Lower values make reputation forget faster; 1.0 = no decay.
631    #[serde(default = "default_reputation_decay_factor")]
632    pub decay_factor: f64,
633    /// Weight of reputation in routing score blend. Range: [0.0, 1.0]. Default: 0.3.
634    ///
635    /// **Warning**: values above 0.5 can aggressively suppress low-reputation providers.
636    /// At `weight = 1.0` with `rep_factor = 0.0` (all failures), the routing score
637    /// drops to zero — the provider becomes unreachable for that session. Stick to
638    /// the default (0.3) unless you intentionally want strong reputation gating.
639    #[serde(default = "default_reputation_weight")]
640    pub weight: f64,
641    /// Minimum quality observations before reputation influences routing. Default: 5.
642    #[serde(default = "default_reputation_min_observations")]
643    pub min_observations: u64,
644    /// Path for persisting reputation state. Defaults to `~/.config/zeph/router_reputation_state.json`.
645    #[serde(default)]
646    pub state_path: Option<String>,
647}
648
649/// Configuration for cascade routing (`strategy = "cascade"`).
650///
651/// Cascade routing tries providers in chain order (cheapest first), escalating to
652/// the next provider when the response is classified as degenerate (empty, repetitive,
653/// incoherent). Chain order determines cost order: first provider = cheapest.
654///
655/// # Limitations
656///
657/// The heuristic classifier detects degenerate outputs only, not semantic failures.
658/// Use `classifier_mode = "judge"` for semantic quality gating (adds LLM call cost).
659#[derive(Debug, Clone, Deserialize, Serialize)]
660pub struct CascadeConfig {
661    /// Minimum quality score [0.0, 1.0] to accept a response without escalating.
662    /// Responses scoring below this threshold trigger escalation.
663    #[serde(default = "default_cascade_quality_threshold")]
664    pub quality_threshold: f64,
665
666    /// Maximum number of quality-based escalations per request.
667    /// Network/API errors do not count against this budget.
668    /// Default: 2 (allows up to 3 providers: cheap → mid → expensive).
669    #[serde(default = "default_cascade_max_escalations")]
670    pub max_escalations: u8,
671
672    /// Quality classifier mode: `"heuristic"` (default) or `"judge"`.
673    /// Heuristic is zero-cost but detects only degenerate outputs.
674    /// Judge requires a configured `summary_model` and adds one LLM call per evaluation.
675    #[serde(default)]
676    pub classifier_mode: CascadeClassifierMode,
677
678    /// Rolling quality history window size per provider. Default: 50.
679    #[serde(default = "default_cascade_window_size")]
680    pub window_size: usize,
681
682    /// Maximum cumulative input+output tokens across all escalation levels.
683    /// When exceeded, returns the best-seen response instead of escalating further.
684    /// `None` disables the budget (unbounded escalation cost).
685    #[serde(default)]
686    pub max_cascade_tokens: Option<u32>,
687
688    /// Explicit cost ordering of provider names (cheapest first).
689    /// When set, cascade routing sorts providers by their position in this list before
690    /// trying them. Providers not in the list are appended after listed ones in their
691    /// original chain order. When unset, chain order is used (default behavior).
692    #[serde(default, skip_serializing_if = "Option::is_none")]
693    pub cost_tiers: Option<Vec<String>>,
694}
695
696impl Default for CascadeConfig {
697    fn default() -> Self {
698        Self {
699            quality_threshold: default_cascade_quality_threshold(),
700            max_escalations: default_cascade_max_escalations(),
701            classifier_mode: CascadeClassifierMode::default(),
702            window_size: default_cascade_window_size(),
703            max_cascade_tokens: None,
704            cost_tiers: None,
705        }
706    }
707}
708
709/// Quality classifier mode for cascade routing.
710#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
711#[serde(rename_all = "lowercase")]
712pub enum CascadeClassifierMode {
713    /// Zero-cost heuristic: detects degenerate outputs (empty, repetitive, incoherent).
714    /// Does not detect semantic failures (hallucinations, wrong answers).
715    #[default]
716    Heuristic,
717    /// LLM-based judge: more accurate but adds latency. Falls back to heuristic on failure.
718    /// Requires `summary_model` to be configured.
719    Judge,
720}
721
/// Default `LinUCB` exploration coefficient.
fn default_bandit_alpha() -> f32 {
    1.0_f32
}

/// Default bandit feature-vector length.
fn default_bandit_dim() -> usize {
    32_usize
}

/// Default cost-penalty weight in the bandit reward.
fn default_bandit_cost_weight() -> f32 {
    0.1_f32
}

/// Default startup decay for bandit arm state (1.0 means no decay).
fn default_bandit_decay_factor() -> f32 {
    1.0_f32
}

/// Default hard timeout, in milliseconds, for the bandit's embedding call.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50_u64
}

/// Default capacity of the bandit's embedding cache.
fn default_bandit_cache_size() -> usize {
    512_usize
}
745
746/// Configuration for PILOT bandit routing (`strategy = "bandit"`).
747///
748/// PILOT (Provider Intelligence via Learned Online Tuning) uses a `LinUCB` contextual
749/// bandit to learn which provider performs best for a given query context. The feature
750/// vector is derived from the query embedding (first `dim` components, L2-normalised).
751///
752/// **Cold start**: the bandit falls back to Thompson sampling for the first
753/// `10 * num_providers` queries (configurable). After warmup, `LinUCB` takes over.
754///
755/// **Embedding**: an `embedding_provider` must be set for feature vectors. If the embed
756/// call exceeds `embedding_timeout_ms` or fails, the bandit falls back to Thompson/uniform.
757/// Use a local provider (Ollama, Candle) to avoid network latency on the hot path.
758#[derive(Debug, Clone, Deserialize, Serialize)]
759pub struct BanditConfig {
760    /// `LinUCB` exploration parameter. Default: 1.0.
761    /// Higher values increase exploration; lower values favour exploitation.
762    #[serde(default = "default_bandit_alpha")]
763    pub alpha: f32,
764
765    /// Feature vector dimension (first `dim` components of the embedding).
766    ///
767    /// This is simple truncation, not PCA. The first raw embedding dimensions do not
768    /// necessarily capture the most variance. For `OpenAI` `text-embedding-3-*` models,
769    /// consider using the `dimensions` API parameter (Matryoshka embeddings) instead.
770    /// Default: 32.
771    #[serde(default = "default_bandit_dim")]
772    pub dim: usize,
773
774    /// Cost penalty weight in the reward signal: `reward = quality - cost_weight * cost_fraction`.
775    /// Default: 0.1. Increase to penalise expensive providers more aggressively.
776    #[serde(default = "default_bandit_cost_weight")]
777    pub cost_weight: f32,
778
779    /// Session-level decay applied to arm state on startup: `A = I + decay*(A-I)`, `b = decay*b`.
780    /// Values < 1.0 cause re-exploration after provider quality changes. Default: 1.0 (no decay).
781    #[serde(default = "default_bandit_decay_factor")]
782    pub decay_factor: f32,
783
784    /// Provider name from `[[llm.providers]]` used for query embeddings.
785    ///
786    /// SLM recommended: prefer a fast local model (e.g. Ollama `nomic-embed-text`,
787    /// Candle, or `text-embedding-3-small`) — this is called on every bandit request.
788    /// Empty string disables `LinUCB` (bandit always falls back to Thompson/uniform).
789    #[serde(default)]
790    pub embedding_provider: ProviderName,
791
792    /// Hard timeout for the embedding call in milliseconds. Default: 50.
793    /// If exceeded, the request falls back to Thompson/uniform selection.
794    #[serde(default = "default_bandit_embedding_timeout_ms")]
795    pub embedding_timeout_ms: u64,
796
797    /// Maximum cached embeddings (keyed by query text hash). Default: 512.
798    #[serde(default = "default_bandit_cache_size")]
799    pub cache_size: usize,
800
801    /// Path for persisting bandit state. Defaults to `~/.config/zeph/router_bandit_state.json`.
802    ///
803    /// # Security
804    ///
805    /// This path is user-controlled. The file is created with mode `0o600` on Unix.
806    /// Do not place it in world-writable directories.
807    #[serde(default)]
808    pub state_path: Option<String>,
809
810    /// MAR (Memory-Augmented Routing) confidence threshold.
811    ///
812    /// When the top-1 semantic recall score for the current query is >= this value,
813    /// the bandit biases toward cheaper providers (the answer is likely in memory).
814    /// Set to 1.0 to disable MAR. Default: 0.9.
815    #[serde(default = "default_bandit_memory_confidence_threshold")]
816    pub memory_confidence_threshold: f32,
817
818    /// Minimum number of queries before `LinUCB` takes over from Thompson warmup.
819    ///
820    /// When unset or `0`, defaults to `10 × number of providers` (computed at startup).
821    /// Set explicitly to control how long the bandit explores uniformly before
822    /// switching to context-aware routing. Setting `0` preserves the computed default.
823    #[serde(default)]
824    pub warmup_queries: Option<u64>,
825}
826
/// Default MAR confidence threshold for bandit routing.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9_f32
}
830
831impl Default for BanditConfig {
832    fn default() -> Self {
833        Self {
834            alpha: default_bandit_alpha(),
835            dim: default_bandit_dim(),
836            cost_weight: default_bandit_cost_weight(),
837            decay_factor: default_bandit_decay_factor(),
838            embedding_provider: ProviderName::default(),
839            embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
840            cache_size: default_bandit_cache_size(),
841            state_path: None,
842            memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
843            warmup_queries: None,
844        }
845    }
846}
847
/// Configuration for the Candle local-inference backend.
///
/// NOTE(review): presumably deserialized from `[llm.candle]` (see the
/// `[llm.candle.generation]` reference on [`GenerationParams`]) — confirm.
///
/// NOTE(review): `Debug` is derived, so `hf_token` is printed verbatim by `{:?}`.
/// Avoid debug-logging this struct, or consider a redacting `Debug` impl.
#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Model source selector. Defaults to `default_candle_source()`.
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path. Empty string when not configured.
    #[serde(default)]
    pub local_path: String,
    /// Optional model file name. `None` when not configured.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat template identifier. Defaults to `default_chat_template()`.
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Inference device selector. Defaults to `default_candle_device()`.
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional embedding model repository (presumably a `HuggingFace` repo id — confirm).
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Resolved `HuggingFace` Hub API token for authenticated model downloads.
    ///
    /// Must be the **token value** — resolved by the caller before constructing this config.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling parameters used for generation; see [`GenerationParams`].
    #[serde(default)]
    pub generation: GenerationParams,
    /// Maximum seconds to wait for each half of a single inference request.
    ///
    /// The timeout is applied **twice** per `chat()` call: once for the channel send
    /// (waiting for a free slot) and once for the oneshot reply (waiting for the worker
    /// to finish). The effective maximum wall-clock wait per request is therefore
    /// `2 × inference_timeout_secs`. CPU inference can be slow; 120s is a conservative
    /// default for large models, giving up to 240s total before an error is returned.
    /// Values of 0 are silently promoted to 1 at bootstrap.
    #[serde(default = "default_inference_timeout_secs")]
    pub inference_timeout_secs: u64,
}
880
/// Default Candle `inference_timeout_secs`: two minutes per half (send / reply) of a
/// request, because CPU inference of large models can be slow.
fn default_inference_timeout_secs() -> u64 {
    2 * 60
}
884
/// Sampling / generation parameters for Candle local inference.
///
/// Used inside `[llm.candle.generation]` or a `[[llm.providers]]` Candle entry.
/// Every field is optional in TOML; missing fields take the value of their
/// `#[serde(default = ...)]` helper, and `top_p` / `top_k` default to disabled.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature. Higher values produce more creative outputs. Default: `0.7`.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Nucleus (top-p) sampling threshold. When set, sampling is restricted to the
    /// smallest set of most-probable tokens whose cumulative probability reaches this
    /// value. Default: `None` (disabled).
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Top-k sampling. When set, only the top-k most probable tokens are considered.
    /// Default: `None` (disabled).
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Maximum number of tokens to generate per response. Default: `2048`.
    ///
    /// The raw value is stored as configured; read it through
    /// [`GenerationParams::capped_max_tokens`] to get it clamped to [`MAX_TOKENS_CAP`].
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// Random seed for reproducible outputs. Default: `42`.
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty applied during sampling. Default: `1.1`.
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Number of last tokens to consider for the repetition penalty window. Default: `64`.
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
915
/// Absolute ceiling for `GenerationParams::max_tokens` (applied by
/// `GenerationParams::capped_max_tokens`), so a misconfigured value cannot
/// request unbounded generation.
pub const MAX_TOKENS_CAP: usize = 32 * 1024;
918
919impl GenerationParams {
920    /// Returns `max_tokens` clamped to [`MAX_TOKENS_CAP`].
921    ///
922    /// # Examples
923    ///
924    /// ```
925    /// use zeph_config::GenerationParams;
926    ///
927    /// let params = GenerationParams::default();
928    /// assert!(params.capped_max_tokens() <= 32768);
929    /// ```
930    #[must_use]
931    pub fn capped_max_tokens(&self) -> usize {
932        self.max_tokens.min(MAX_TOKENS_CAP)
933    }
934}
935
936impl Default for GenerationParams {
937    fn default() -> Self {
938        Self {
939            temperature: default_temperature(),
940            top_p: None,
941            top_k: None,
942            max_tokens: default_max_tokens(),
943            seed: default_seed(),
944            repeat_penalty: default_repeat_penalty(),
945            repeat_last_n: default_repeat_last_n(),
946        }
947    }
948}
949
950// ─── Unified config types ─────────────────────────────────────────────────────
951
/// Routing strategy for the `[[llm.providers]]` pool.
///
/// Selected with `[llm] routing = "<name>"`, where `<name>` is the lowercase variant
/// name (e.g. `routing = "triage"`). Defaults to [`LlmRoutingStrategy::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// Single provider or first-in-pool (default).
    #[default]
    None,
    /// Exponential moving average latency-aware ordering.
    Ema,
    /// Thompson Sampling with Beta distributions.
    Thompson,
    /// Cascade: try cheapest provider first, escalate on degenerate output.
    Cascade,
    /// Task-based routing using `[llm.routes]` map.
    Task,
    /// Complexity triage routing: pre-classify each request, delegate to appropriate tier.
    /// Configured via `[llm.complexity_routing]` (see `ComplexityRoutingConfig`).
    Triage,
    /// PILOT: `LinUCB` contextual bandit with online learning and budget-aware reward.
    Bandit,
}
972
/// Serde default for `ComplexityRoutingConfig::triage_timeout_secs`: five seconds.
fn default_triage_timeout_secs() -> u64 {
    5_u64
}
976
/// Serde default for `ComplexityRoutingConfig::max_triage_tokens`: the triage call only
/// needs a short classification, so fifty output tokens suffice.
fn default_max_triage_tokens() -> u32 {
    50_u32
}
980
/// Serde helper for boolean fields that default to enabled
/// (`#[serde(default = "default_true")]`).
fn default_true() -> bool {
    true
}
984
/// Tier-to-provider name mapping for complexity routing.
///
/// Each field holds a provider name (an entry in `[[llm.providers]]`) used for requests
/// the triage model classifies into that tier. Tiers left out of the TOML deserialize
/// as `None`.
///
/// NOTE(review): unlike `ComplexityRoutingConfig::triage_provider`, these are plain
/// `String`s rather than `ProviderName` — confirm whether they are cross-checked
/// against the provider pool at load time.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    /// Provider for requests classified as simple.
    pub simple: Option<String>,
    /// Provider for requests classified as medium.
    pub medium: Option<String>,
    /// Provider for requests classified as complex.
    pub complex: Option<String>,
    /// Provider for requests classified as expert.
    pub expert: Option<String>,
}
993
/// Configuration for complexity-based triage routing (`routing = "triage"`).
///
/// When `[llm] routing = "triage"` is set, a cheap triage model classifies each request
/// and routes it to the appropriate tier provider. Requires at least one tier mapping.
///
/// All fields are optional in TOML; the serde defaults below match the `Default` impl.
///
/// # Example
///
/// ```toml
/// [llm]
/// routing = "triage"
///
/// [llm.complexity_routing]
/// triage_provider = "local-fast"
///
/// [llm.complexity_routing.tiers]
/// simple = "local-fast"
/// medium = "haiku"
/// complex = "sonnet"
/// expert = "opus"
/// ```
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider name from `[[llm.providers]]` used for triage classification.
    /// `None` when not configured.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Skip triage when all tiers map to the same provider. Default: `true`.
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Tier-to-provider name mapping.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Max output tokens for the triage classification call. Default: 50.
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Timeout in seconds for the triage classification call. Default: 5.
    /// On timeout, falls back to the default (first) tier provider.
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Optional fallback strategy when triage misclassifies.
    /// Only `"cascade"` is currently supported (Phase 4).
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
1037    /// Optional fallback strategy when triage misclassifies.
1038    /// Only `"cascade"` is currently supported (Phase 4).
1039    #[serde(default)]
1040    pub fallback_strategy: Option<String>,
1041}
1042
1043impl Default for ComplexityRoutingConfig {
1044    fn default() -> Self {
1045        Self {
1046            triage_provider: None,
1047            bypass_single_provider: true,
1048            tiers: TierMapping::default(),
1049            max_triage_tokens: default_max_triage_tokens(),
1050            triage_timeout_secs: default_triage_timeout_secs(),
1051            fallback_strategy: None,
1052        }
1053    }
1054}
1055
1056/// Inline candle config for use inside `ProviderEntry`.
1057/// Re-uses the generation params from `CandleConfig`.
1058#[derive(Debug, Clone, Deserialize, Serialize)]
1059pub struct CandleInlineConfig {
1060    #[serde(default = "default_candle_source")]
1061    pub source: String,
1062    #[serde(default)]
1063    pub local_path: String,
1064    #[serde(default)]
1065    pub filename: Option<String>,
1066    #[serde(default = "default_chat_template")]
1067    pub chat_template: String,
1068    #[serde(default = "default_candle_device")]
1069    pub device: String,
1070    #[serde(default)]
1071    pub embedding_repo: Option<String>,
1072    /// Resolved `HuggingFace` Hub API token for authenticated model downloads.
1073    #[serde(default)]
1074    pub hf_token: Option<String>,
1075    #[serde(default)]
1076    pub generation: GenerationParams,
1077    /// Maximum wall-clock seconds to wait for a single inference request.
1078    ///
1079    /// Effective timeout is `2 × inference_timeout_secs` (send + recv each have this budget).
1080    /// CPU inference can be slow; 120s is a conservative default. Floored at 1s.
1081    #[serde(default = "default_inference_timeout_secs")]
1082    pub inference_timeout_secs: u64,
1083}
1084
1085impl Default for CandleInlineConfig {
1086    fn default() -> Self {
1087        Self {
1088            source: default_candle_source(),
1089            local_path: String::new(),
1090            filename: None,
1091            chat_template: default_chat_template(),
1092            device: default_candle_device(),
1093            embedding_repo: None,
1094            hf_token: None,
1095            generation: GenerationParams::default(),
1096            inference_timeout_secs: default_inference_timeout_secs(),
1097        }
1098    }
1099}
1100
1101/// Unified provider entry: one struct replaces `CloudLlmConfig`, `OpenAiConfig`,
1102/// `GeminiConfig`, `OllamaConfig`, `CompatibleConfig`, and `OrchestratorProviderConfig`.
1103///
1104/// Provider-specific fields use `#[serde(default)]` and are ignored by backends
1105/// that do not use them (flat-union pattern).
1106#[derive(Debug, Clone, Deserialize, Serialize)]
1107#[allow(clippy::struct_excessive_bools)]
1108pub struct ProviderEntry {
1109    /// Required: provider backend type.
1110    #[serde(rename = "type")]
1111    pub provider_type: ProviderKind,
1112
1113    /// Optional name for multi-provider configs. Auto-generated from type if absent.
1114    #[serde(default)]
1115    pub name: Option<String>,
1116
1117    /// Model identifier. Required for most types.
1118    #[serde(default)]
1119    pub model: Option<String>,
1120
1121    /// API base URL. Each type has its own default.
1122    #[serde(default)]
1123    pub base_url: Option<String>,
1124
1125    /// Max output tokens.
1126    #[serde(default)]
1127    pub max_tokens: Option<u32>,
1128
1129    /// Embedding model. When set, this provider supports `embed()` calls.
1130    #[serde(default)]
1131    pub embedding_model: Option<String>,
1132
1133    /// STT model. When set, this provider supports speech-to-text via the Whisper API or
1134    /// Candle-local inference.
1135    #[serde(default)]
1136    pub stt_model: Option<String>,
1137
1138    /// Mark this entry as the embedding provider (handles `embed()` calls).
1139    #[serde(default)]
1140    pub embed: bool,
1141
1142    /// Mark this entry as the default chat provider (overrides position-based default).
1143    #[serde(default)]
1144    pub default: bool,
1145
1146    // --- Claude-specific ---
1147    #[serde(default)]
1148    pub thinking: Option<ThinkingConfig>,
1149    #[serde(default)]
1150    pub server_compaction: bool,
1151    #[serde(default)]
1152    pub enable_extended_context: bool,
1153
1154    // --- OpenAI-specific ---
1155    #[serde(default)]
1156    pub reasoning_effort: Option<String>,
1157
1158    // --- Gemini-specific ---
1159    #[serde(default)]
1160    pub thinking_level: Option<GeminiThinkingLevel>,
1161    #[serde(default)]
1162    pub thinking_budget: Option<i32>,
1163    #[serde(default)]
1164    pub include_thoughts: Option<bool>,
1165
1166    // --- Compatible-specific: optional inline api_key ---
1167    #[serde(default)]
1168    pub api_key: Option<String>,
1169
1170    // --- Candle-specific ---
1171    #[serde(default)]
1172    pub candle: Option<CandleInlineConfig>,
1173
1174    // --- Vision ---
1175    #[serde(default)]
1176    pub vision_model: Option<String>,
1177
1178    /// Provider-specific instruction file.
1179    #[serde(default)]
1180    pub instruction_file: Option<std::path::PathBuf>,
1181}
1182
1183impl Default for ProviderEntry {
1184    fn default() -> Self {
1185        Self {
1186            provider_type: ProviderKind::Ollama,
1187            name: None,
1188            model: None,
1189            base_url: None,
1190            max_tokens: None,
1191            embedding_model: None,
1192            stt_model: None,
1193            embed: false,
1194            default: false,
1195            thinking: None,
1196            server_compaction: false,
1197            enable_extended_context: false,
1198            reasoning_effort: None,
1199            thinking_level: None,
1200            thinking_budget: None,
1201            include_thoughts: None,
1202            api_key: None,
1203            candle: None,
1204            vision_model: None,
1205            instruction_file: None,
1206        }
1207    }
1208}
1209
1210impl ProviderEntry {
1211    /// Resolve the effective name: explicit `name` field or type string.
1212    #[must_use]
1213    pub fn effective_name(&self) -> String {
1214        self.name
1215            .clone()
1216            .unwrap_or_else(|| self.provider_type.as_str().to_owned())
1217    }
1218
1219    /// Resolve the effective model: explicit `model` field or the provider-type default.
1220    ///
1221    /// Defaults mirror those used in `build_provider_from_entry` so that `runtime.model_name`
1222    /// always reflects the actual model being used rather than the provider type string.
1223    #[must_use]
1224    pub fn effective_model(&self) -> String {
1225        if let Some(ref m) = self.model {
1226            return m.clone();
1227        }
1228        match self.provider_type {
1229            ProviderKind::Ollama => "qwen3:8b".to_owned(),
1230            ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1231            ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1232            ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1233            ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1234        }
1235    }
1236
1237    /// Validate this entry for cross-field consistency.
1238    ///
1239    /// # Errors
1240    ///
1241    /// Returns `ConfigError` when a fatal invariant is violated (e.g. compatible provider
1242    /// without a name).
1243    pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1244        use crate::error::ConfigError;
1245
1246        // B2: compatible provider MUST have name set.
1247        if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1248            return Err(ConfigError::Validation(
1249                "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1250            ));
1251        }
1252
1253        // B1: warn on irrelevant fields.
1254        match self.provider_type {
1255            ProviderKind::Ollama => {
1256                if self.thinking.is_some() {
1257                    tracing::warn!(
1258                        provider = self.effective_name(),
1259                        "field `thinking` is only used by Claude providers"
1260                    );
1261                }
1262                if self.reasoning_effort.is_some() {
1263                    tracing::warn!(
1264                        provider = self.effective_name(),
1265                        "field `reasoning_effort` is only used by OpenAI providers"
1266                    );
1267                }
1268                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1269                    tracing::warn!(
1270                        provider = self.effective_name(),
1271                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1272                    );
1273                }
1274            }
1275            ProviderKind::Claude => {
1276                if self.reasoning_effort.is_some() {
1277                    tracing::warn!(
1278                        provider = self.effective_name(),
1279                        "field `reasoning_effort` is only used by OpenAI providers"
1280                    );
1281                }
1282                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1283                    tracing::warn!(
1284                        provider = self.effective_name(),
1285                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1286                    );
1287                }
1288            }
1289            ProviderKind::OpenAi => {
1290                if self.thinking.is_some() {
1291                    tracing::warn!(
1292                        provider = self.effective_name(),
1293                        "field `thinking` is only used by Claude providers"
1294                    );
1295                }
1296                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1297                    tracing::warn!(
1298                        provider = self.effective_name(),
1299                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1300                    );
1301                }
1302            }
1303            ProviderKind::Gemini => {
1304                if self.thinking.is_some() {
1305                    tracing::warn!(
1306                        provider = self.effective_name(),
1307                        "field `thinking` is only used by Claude providers"
1308                    );
1309                }
1310                if self.reasoning_effort.is_some() {
1311                    tracing::warn!(
1312                        provider = self.effective_name(),
1313                        "field `reasoning_effort` is only used by OpenAI providers"
1314                    );
1315                }
1316            }
1317            _ => {}
1318        }
1319
1320        // W6: Candle STT-only provider (stt_model set, no model) is valid — no warning needed.
1321        // Warn if Ollama has stt_model set (Ollama does not support Whisper API).
1322        if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1323            tracing::warn!(
1324                provider = self.effective_name(),
1325                "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1326                 Whisper STT API — use OpenAI, compatible, or candle instead"
1327            );
1328        }
1329
1330        Ok(())
1331    }
1332}
1333
1334/// Validate a pool of `ProviderEntry` items.
1335///
1336/// # Errors
1337///
1338/// Returns `ConfigError` for fatal validation failures:
1339/// - Empty pool
1340/// - Duplicate names
1341/// - Multiple entries marked `default = true`
1342/// - Individual entry validation errors
1343pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1344    use crate::error::ConfigError;
1345    use std::collections::HashSet;
1346
1347    if entries.is_empty() {
1348        return Err(ConfigError::Validation(
1349            "at least one LLM provider must be configured in [[llm.providers]]".into(),
1350        ));
1351    }
1352
1353    let default_count = entries.iter().filter(|e| e.default).count();
1354    if default_count > 1 {
1355        return Err(ConfigError::Validation(
1356            "only one [[llm.providers]] entry can be marked `default = true`".into(),
1357        ));
1358    }
1359
1360    let mut seen_names: HashSet<String> = HashSet::new();
1361    for entry in entries {
1362        let name = entry.effective_name();
1363        if !seen_names.insert(name.clone()) {
1364            return Err(ConfigError::Validation(format!(
1365                "duplicate provider name \"{name}\" in [[llm.providers]]"
1366            )));
1367        }
1368        entry.validate()?;
1369    }
1370
1371    Ok(())
1372}
1373
1374#[cfg(test)]
1375mod tests {
1376    use super::*;
1377
1378    fn ollama_entry() -> ProviderEntry {
1379        ProviderEntry {
1380            provider_type: ProviderKind::Ollama,
1381            name: Some("ollama".into()),
1382            model: Some("qwen3:8b".into()),
1383            ..Default::default()
1384        }
1385    }
1386
1387    fn claude_entry() -> ProviderEntry {
1388        ProviderEntry {
1389            provider_type: ProviderKind::Claude,
1390            name: Some("claude".into()),
1391            model: Some("claude-sonnet-4-6".into()),
1392            max_tokens: Some(8192),
1393            ..Default::default()
1394        }
1395    }
1396
1397    // ─── ProviderEntry::validate ─────────────────────────────────────────────
1398
1399    #[test]
1400    fn validate_ollama_valid() {
1401        assert!(ollama_entry().validate().is_ok());
1402    }
1403
1404    #[test]
1405    fn validate_claude_valid() {
1406        assert!(claude_entry().validate().is_ok());
1407    }
1408
1409    #[test]
1410    fn validate_compatible_without_name_errors() {
1411        let entry = ProviderEntry {
1412            provider_type: ProviderKind::Compatible,
1413            name: None,
1414            ..Default::default()
1415        };
1416        let err = entry.validate().unwrap_err();
1417        assert!(
1418            err.to_string().contains("compatible"),
1419            "error should mention compatible: {err}"
1420        );
1421    }
1422
1423    #[test]
1424    fn validate_compatible_with_name_ok() {
1425        let entry = ProviderEntry {
1426            provider_type: ProviderKind::Compatible,
1427            name: Some("my-proxy".into()),
1428            base_url: Some("http://localhost:8080".into()),
1429            model: Some("gpt-4o".into()),
1430            max_tokens: Some(4096),
1431            ..Default::default()
1432        };
1433        assert!(entry.validate().is_ok());
1434    }
1435
1436    #[test]
1437    fn validate_openai_valid() {
1438        let entry = ProviderEntry {
1439            provider_type: ProviderKind::OpenAi,
1440            name: Some("openai".into()),
1441            model: Some("gpt-4o".into()),
1442            max_tokens: Some(4096),
1443            ..Default::default()
1444        };
1445        assert!(entry.validate().is_ok());
1446    }
1447
1448    #[test]
1449    fn validate_gemini_valid() {
1450        let entry = ProviderEntry {
1451            provider_type: ProviderKind::Gemini,
1452            name: Some("gemini".into()),
1453            model: Some("gemini-2.0-flash".into()),
1454            ..Default::default()
1455        };
1456        assert!(entry.validate().is_ok());
1457    }
1458
1459    // ─── validate_pool ───────────────────────────────────────────────────────
1460
1461    #[test]
1462    fn validate_pool_empty_errors() {
1463        let err = validate_pool(&[]).unwrap_err();
1464        assert!(err.to_string().contains("at least one"), "{err}");
1465    }
1466
1467    #[test]
1468    fn validate_pool_single_entry_ok() {
1469        assert!(validate_pool(&[ollama_entry()]).is_ok());
1470    }
1471
1472    #[test]
1473    fn validate_pool_duplicate_names_errors() {
1474        let a = ollama_entry();
1475        let b = ollama_entry(); // same effective name "ollama"
1476        let err = validate_pool(&[a, b]).unwrap_err();
1477        assert!(err.to_string().contains("duplicate"), "{err}");
1478    }
1479
1480    #[test]
1481    fn validate_pool_multiple_defaults_errors() {
1482        let mut a = ollama_entry();
1483        let mut b = claude_entry();
1484        a.default = true;
1485        b.default = true;
1486        let err = validate_pool(&[a, b]).unwrap_err();
1487        assert!(err.to_string().contains("default"), "{err}");
1488    }
1489
1490    #[test]
1491    fn validate_pool_two_different_providers_ok() {
1492        assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
1493    }
1494
1495    #[test]
1496    fn validate_pool_propagates_entry_error() {
1497        let bad = ProviderEntry {
1498            provider_type: ProviderKind::Compatible,
1499            name: None, // invalid: compatible without name
1500            ..Default::default()
1501        };
1502        assert!(validate_pool(&[bad]).is_err());
1503    }
1504
1505    // ─── ProviderEntry::effective_model ──────────────────────────────────────
1506
1507    #[test]
1508    fn effective_model_returns_explicit_when_set() {
1509        let entry = ProviderEntry {
1510            provider_type: ProviderKind::Claude,
1511            model: Some("claude-sonnet-4-6".into()),
1512            ..Default::default()
1513        };
1514        assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
1515    }
1516
1517    #[test]
1518    fn effective_model_ollama_default_when_none() {
1519        let entry = ProviderEntry {
1520            provider_type: ProviderKind::Ollama,
1521            model: None,
1522            ..Default::default()
1523        };
1524        assert_eq!(entry.effective_model(), "qwen3:8b");
1525    }
1526
1527    #[test]
1528    fn effective_model_claude_default_when_none() {
1529        let entry = ProviderEntry {
1530            provider_type: ProviderKind::Claude,
1531            model: None,
1532            ..Default::default()
1533        };
1534        assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
1535    }
1536
1537    #[test]
1538    fn effective_model_openai_default_when_none() {
1539        let entry = ProviderEntry {
1540            provider_type: ProviderKind::OpenAi,
1541            model: None,
1542            ..Default::default()
1543        };
1544        assert_eq!(entry.effective_model(), "gpt-4o-mini");
1545    }
1546
1547    #[test]
1548    fn effective_model_gemini_default_when_none() {
1549        let entry = ProviderEntry {
1550            provider_type: ProviderKind::Gemini,
1551            model: None,
1552            ..Default::default()
1553        };
1554        assert_eq!(entry.effective_model(), "gemini-2.0-flash");
1555    }
1556
1557    // ─── LlmConfig::check_legacy_format ──────────────────────────────────────
1558
1559    // Parse a complete TOML snippet that includes the [llm] header.
1560    fn parse_llm(toml: &str) -> LlmConfig {
1561        #[derive(serde::Deserialize)]
1562        struct Wrapper {
1563            llm: LlmConfig,
1564        }
1565        toml::from_str::<Wrapper>(toml).unwrap().llm
1566    }
1567
1568    #[test]
1569    fn check_legacy_format_new_format_ok() {
1570        let cfg = parse_llm(
1571            r#"
1572[llm]
1573
1574[[llm.providers]]
1575type = "ollama"
1576model = "qwen3:8b"
1577"#,
1578        );
1579        assert!(cfg.check_legacy_format().is_ok());
1580    }
1581
1582    #[test]
1583    fn check_legacy_format_empty_providers_no_legacy_ok() {
1584        // No providers, no legacy fields — passes (empty [llm] is acceptable here)
1585        let cfg = parse_llm("[llm]\n");
1586        assert!(cfg.check_legacy_format().is_ok());
1587    }
1588
1589    // ─── LlmConfig::effective_* helpers ──────────────────────────────────────
1590
1591    #[test]
1592    fn effective_provider_falls_back_to_ollama_when_no_providers() {
1593        let cfg = parse_llm("[llm]\n");
1594        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
1595    }
1596
1597    #[test]
1598    fn effective_provider_reads_from_providers_first() {
1599        let cfg = parse_llm(
1600            r#"
1601[llm]
1602
1603[[llm.providers]]
1604type = "claude"
1605model = "claude-sonnet-4-6"
1606"#,
1607        );
1608        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
1609    }
1610
1611    #[test]
1612    fn effective_model_reads_from_providers_first() {
1613        let cfg = parse_llm(
1614            r#"
1615[llm]
1616
1617[[llm.providers]]
1618type = "ollama"
1619model = "qwen3:8b"
1620"#,
1621        );
1622        assert_eq!(cfg.effective_model(), "qwen3:8b");
1623    }
1624
1625    #[test]
1626    fn effective_base_url_default_when_absent() {
1627        let cfg = parse_llm("[llm]\n");
1628        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
1629    }
1630
1631    #[test]
1632    fn effective_base_url_from_providers_entry() {
1633        let cfg = parse_llm(
1634            r#"
1635[llm]
1636
1637[[llm.providers]]
1638type = "ollama"
1639base_url = "http://myhost:11434"
1640"#,
1641        );
1642        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
1643    }
1644
1645    // ─── ComplexityRoutingConfig / LlmRoutingStrategy::Triage TOML parsing ──
1646
1647    #[test]
1648    fn complexity_routing_defaults() {
1649        let cr = ComplexityRoutingConfig::default();
1650        assert!(
1651            cr.bypass_single_provider,
1652            "bypass_single_provider must default to true"
1653        );
1654        assert_eq!(cr.triage_timeout_secs, 5);
1655        assert_eq!(cr.max_triage_tokens, 50);
1656        assert!(cr.triage_provider.is_none());
1657        assert!(cr.tiers.simple.is_none());
1658    }
1659
1660    #[test]
1661    fn complexity_routing_toml_round_trip() {
1662        let cfg = parse_llm(
1663            r#"
1664[llm]
1665routing = "triage"
1666
1667[llm.complexity_routing]
1668triage_provider = "fast"
1669bypass_single_provider = false
1670triage_timeout_secs = 10
1671max_triage_tokens = 100
1672
1673[llm.complexity_routing.tiers]
1674simple = "fast"
1675medium = "medium"
1676complex = "large"
1677expert = "opus"
1678"#,
1679        );
1680        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1681        let cr = cfg
1682            .complexity_routing
1683            .expect("complexity_routing must be present");
1684        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
1685        assert!(!cr.bypass_single_provider);
1686        assert_eq!(cr.triage_timeout_secs, 10);
1687        assert_eq!(cr.max_triage_tokens, 100);
1688        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
1689        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
1690        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
1691        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
1692    }
1693
1694    #[test]
1695    fn complexity_routing_partial_tiers_toml() {
1696        // Only simple + complex configured; medium and expert are None.
1697        let cfg = parse_llm(
1698            r#"
1699[llm]
1700routing = "triage"
1701
1702[llm.complexity_routing.tiers]
1703simple = "haiku"
1704complex = "sonnet"
1705"#,
1706        );
1707        let cr = cfg
1708            .complexity_routing
1709            .expect("complexity_routing must be present");
1710        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
1711        assert!(cr.tiers.medium.is_none());
1712        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
1713        assert!(cr.tiers.expert.is_none());
1714        // Defaults still applied.
1715        assert!(cr.bypass_single_provider);
1716        assert_eq!(cr.triage_timeout_secs, 5);
1717    }
1718
1719    #[test]
1720    fn routing_strategy_triage_deserialized() {
1721        let cfg = parse_llm(
1722            r#"
1723[llm]
1724routing = "triage"
1725"#,
1726        );
1727        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1728    }
1729
1730    // ─── stt_provider_entry ───────────────────────────────────────────────────
1731
1732    #[test]
1733    fn stt_provider_entry_by_name_match() {
1734        let cfg = parse_llm(
1735            r#"
1736[llm]
1737
1738[[llm.providers]]
1739type = "openai"
1740name = "quality"
1741model = "gpt-5.4"
1742stt_model = "gpt-4o-mini-transcribe"
1743
1744[llm.stt]
1745provider = "quality"
1746"#,
1747        );
1748        let entry = cfg.stt_provider_entry().expect("should find stt provider");
1749        assert_eq!(entry.effective_name(), "quality");
1750        assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
1751    }
1752
1753    #[test]
1754    fn stt_provider_entry_auto_detect_when_provider_empty() {
1755        let cfg = parse_llm(
1756            r#"
1757[llm]
1758
1759[[llm.providers]]
1760type = "openai"
1761name = "openai-stt"
1762stt_model = "whisper-1"
1763
1764[llm.stt]
1765provider = ""
1766"#,
1767        );
1768        let entry = cfg.stt_provider_entry().expect("should auto-detect");
1769        assert_eq!(entry.effective_name(), "openai-stt");
1770    }
1771
1772    #[test]
1773    fn stt_provider_entry_auto_detect_no_stt_section() {
1774        let cfg = parse_llm(
1775            r#"
1776[llm]
1777
1778[[llm.providers]]
1779type = "openai"
1780name = "openai-stt"
1781stt_model = "whisper-1"
1782"#,
1783        );
1784        // No [llm.stt] section — should still find first provider with stt_model.
1785        let entry = cfg.stt_provider_entry().expect("should auto-detect");
1786        assert_eq!(entry.effective_name(), "openai-stt");
1787    }
1788
1789    #[test]
1790    fn stt_provider_entry_none_when_no_stt_model() {
1791        let cfg = parse_llm(
1792            r#"
1793[llm]
1794
1795[[llm.providers]]
1796type = "openai"
1797name = "quality"
1798model = "gpt-5.4"
1799"#,
1800        );
1801        assert!(cfg.stt_provider_entry().is_none());
1802    }
1803
1804    #[test]
1805    fn stt_provider_entry_name_mismatch_falls_back_to_none() {
1806        // Named provider exists but has no stt_model; another unnamed has stt_model.
1807        let cfg = parse_llm(
1808            r#"
1809[llm]
1810
1811[[llm.providers]]
1812type = "openai"
1813name = "quality"
1814model = "gpt-5.4"
1815
1816[[llm.providers]]
1817type = "openai"
1818name = "openai-stt"
1819stt_model = "whisper-1"
1820
1821[llm.stt]
1822provider = "quality"
1823"#,
1824        );
1825        // "quality" has no stt_model — returns None for name-based lookup.
1826        assert!(cfg.stt_provider_entry().is_none());
1827    }
1828
1829    #[test]
1830    fn stt_config_deserializes_new_slim_format() {
1831        let cfg = parse_llm(
1832            r#"
1833[llm]
1834
1835[[llm.providers]]
1836type = "openai"
1837name = "quality"
1838stt_model = "whisper-1"
1839
1840[llm.stt]
1841provider = "quality"
1842language = "en"
1843"#,
1844        );
1845        let stt = cfg.stt.as_ref().expect("stt section present");
1846        assert_eq!(stt.provider, "quality");
1847        assert_eq!(stt.language, "en");
1848    }
1849
1850    #[test]
1851    fn stt_config_default_provider_is_empty() {
1852        // Verify that W4 fix: default_stt_provider() returns "" not "whisper".
1853        assert_eq!(default_stt_provider(), "");
1854    }
1855
1856    #[test]
1857    fn validate_stt_missing_provider_ok() {
1858        let cfg = parse_llm("[llm]\n");
1859        assert!(cfg.validate_stt().is_ok());
1860    }
1861
1862    #[test]
1863    fn validate_stt_valid_reference() {
1864        let cfg = parse_llm(
1865            r#"
1866[llm]
1867
1868[[llm.providers]]
1869type = "openai"
1870name = "quality"
1871stt_model = "whisper-1"
1872
1873[llm.stt]
1874provider = "quality"
1875"#,
1876        );
1877        assert!(cfg.validate_stt().is_ok());
1878    }
1879
1880    #[test]
1881    fn validate_stt_nonexistent_provider_errors() {
1882        let cfg = parse_llm(
1883            r#"
1884[llm]
1885
1886[[llm.providers]]
1887type = "openai"
1888name = "quality"
1889model = "gpt-5.4"
1890
1891[llm.stt]
1892provider = "nonexistent"
1893"#,
1894        );
1895        assert!(cfg.validate_stt().is_err());
1896    }
1897
1898    #[test]
1899    fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1900        // MEDIUM: provider is found but has no stt_model — should return Ok (warn path, not error).
1901        let cfg = parse_llm(
1902            r#"
1903[llm]
1904
1905[[llm.providers]]
1906type = "openai"
1907name = "quality"
1908model = "gpt-5.4"
1909
1910[llm.stt]
1911provider = "quality"
1912"#,
1913        );
1914        // validate_stt must succeed (only a tracing::warn is emitted — not an error).
1915        assert!(cfg.validate_stt().is_ok());
1916        // stt_provider_entry must return None because no stt_model is set.
1917        assert!(
1918            cfg.stt_provider_entry().is_none(),
1919            "stt_provider_entry must be None when provider has no stt_model"
1920        );
1921    }
1922
1923    // ─── BanditConfig::warmup_queries deserialization ─────────────────────────
1924
1925    #[test]
1926    fn bandit_warmup_queries_explicit_value_is_deserialized() {
1927        let cfg = parse_llm(
1928            r#"
1929[llm]
1930
1931[llm.router]
1932strategy = "bandit"
1933
1934[llm.router.bandit]
1935warmup_queries = 50
1936"#,
1937        );
1938        let bandit = cfg
1939            .router
1940            .expect("router section must be present")
1941            .bandit
1942            .expect("bandit section must be present");
1943        assert_eq!(
1944            bandit.warmup_queries,
1945            Some(50),
1946            "warmup_queries = 50 must deserialize to Some(50)"
1947        );
1948    }
1949
1950    #[test]
1951    fn bandit_warmup_queries_explicit_null_is_none() {
1952        // Explicitly writing the field as absent: field simply not present is
1953        // equivalent due to #[serde(default)]. Test that an explicit 0 is Some(0).
1954        let cfg = parse_llm(
1955            r#"
1956[llm]
1957
1958[llm.router]
1959strategy = "bandit"
1960
1961[llm.router.bandit]
1962warmup_queries = 0
1963"#,
1964        );
1965        let bandit = cfg
1966            .router
1967            .expect("router section must be present")
1968            .bandit
1969            .expect("bandit section must be present");
1970        // 0 is a valid explicit value — it means "preserve computed default".
1971        assert_eq!(
1972            bandit.warmup_queries,
1973            Some(0),
1974            "warmup_queries = 0 must deserialize to Some(0)"
1975        );
1976    }
1977
1978    #[test]
1979    fn bandit_warmup_queries_missing_field_defaults_to_none() {
1980        // When warmup_queries is omitted entirely, #[serde(default)] must produce None.
1981        let cfg = parse_llm(
1982            r#"
1983[llm]
1984
1985[llm.router]
1986strategy = "bandit"
1987
1988[llm.router.bandit]
1989alpha = 1.5
1990"#,
1991        );
1992        let bandit = cfg
1993            .router
1994            .expect("router section must be present")
1995            .bandit
1996            .expect("bandit section must be present");
1997        assert_eq!(
1998            bandit.warmup_queries, None,
1999            "omitted warmup_queries must default to None"
2000        );
2001    }
2002
2003    #[test]
2004    fn provider_name_new_and_as_str() {
2005        let n = ProviderName::new("fast");
2006        assert_eq!(n.as_str(), "fast");
2007        assert!(!n.is_empty());
2008    }
2009
2010    #[test]
2011    fn provider_name_default_is_empty() {
2012        let n = ProviderName::default();
2013        assert!(n.is_empty());
2014        assert_eq!(n.as_str(), "");
2015    }
2016
2017    #[test]
2018    fn provider_name_deref_to_str() {
2019        let n = ProviderName::new("quality");
2020        let s: &str = &n;
2021        assert_eq!(s, "quality");
2022    }
2023
2024    #[test]
2025    fn provider_name_partial_eq_str() {
2026        let n = ProviderName::new("fast");
2027        assert_eq!(n, "fast");
2028        assert_ne!(n, "slow");
2029    }
2030
2031    #[test]
2032    fn provider_name_serde_roundtrip() {
2033        let n = ProviderName::new("my-provider");
2034        let json = serde_json::to_string(&n).expect("serialize");
2035        assert_eq!(json, "\"my-provider\"");
2036        let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2037        assert_eq!(back, n);
2038    }
2039
2040    #[test]
2041    fn provider_name_serde_empty_roundtrip() {
2042        let n = ProviderName::default();
2043        let json = serde_json::to_string(&n).expect("serialize");
2044        assert_eq!(json, "\"\"");
2045        let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2046        assert_eq!(back, n);
2047        assert!(back.is_empty());
2048    }
2049}