1use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{CacheTtl, GeminiThinkingLevel, ThinkingConfig};
8
9#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(transparent)]
15pub struct ProviderName(String);
16
17impl ProviderName {
18 #[must_use]
32 pub fn new(name: impl Into<String>) -> Self {
33 Self(name.into())
34 }
35
36 #[must_use]
47 pub fn is_empty(&self) -> bool {
48 self.0.is_empty()
49 }
50
51 #[must_use]
62 pub fn as_str(&self) -> &str {
63 &self.0
64 }
65}
66
67impl fmt::Display for ProviderName {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 self.0.fmt(f)
70 }
71}
72
73impl AsRef<str> for ProviderName {
74 fn as_ref(&self) -> &str {
75 &self.0
76 }
77}
78
79impl std::ops::Deref for ProviderName {
80 type Target = str;
81
82 fn deref(&self) -> &str {
83 &self.0
84 }
85}
86
87impl PartialEq<str> for ProviderName {
88 fn eq(&self, other: &str) -> bool {
89 self.0 == other
90 }
91}
92
93impl PartialEq<&str> for ProviderName {
94 fn eq(&self, other: &&str) -> bool {
95 self.0 == *other
96 }
97}
98
// Serde default helpers. Each pins the fallback value for the config field of
// the same name; referenced via `#[serde(default = "...")]` below.

/// Default for `LlmConfig::response_cache_ttl_secs` (one hour).
fn default_response_cache_ttl_secs() -> u64 {
    3600
}

/// Default for `LlmConfig::semantic_cache_threshold`.
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Default for `LlmConfig::semantic_cache_max_candidates`.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

/// Default for `LlmConfig::router_ema_alpha`.
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Default for `LlmConfig::router_reorder_interval`.
fn default_router_reorder_interval() -> u64 {
    10
}

/// Default for `LlmConfig::embedding_model`.
fn default_embedding_model() -> String {
    "qwen3-embedding".into()
}

/// Default for `CandleConfig`/`CandleInlineConfig::source`.
fn default_candle_source() -> String {
    "huggingface".into()
}

/// Default for `CandleConfig`/`CandleInlineConfig::chat_template`.
fn default_chat_template() -> String {
    "chatml".into()
}

/// Default for `CandleConfig`/`CandleInlineConfig::device`.
fn default_candle_device() -> String {
    "cpu".into()
}

/// Default for `GenerationParams::temperature`.
fn default_temperature() -> f64 {
    0.7
}

/// Default for `GenerationParams::max_tokens`.
fn default_max_tokens() -> usize {
    2048
}

/// Default for `GenerationParams::seed`.
fn default_seed() -> u64 {
    42
}

/// Default for `GenerationParams::repeat_penalty`.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Default for `GenerationParams::repeat_last_n`.
fn default_repeat_last_n() -> usize {
    64
}

/// Default for `CascadeConfig::quality_threshold`.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Default for `CascadeConfig::max_escalations`.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Default for `CascadeConfig::window_size`.
fn default_cascade_window_size() -> usize {
    50
}

/// Default for `ReputationConfig::decay_factor`.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Default for `ReputationConfig::weight`.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Default for `ReputationConfig::min_observations`.
fn default_reputation_min_observations() -> u64 {
    5
}
178
/// Default STT provider name; empty means "first provider with an `stt_model`"
/// (see `LlmConfig::stt_provider_entry`).
#[must_use]
pub fn default_stt_provider() -> String {
    String::new()
}

/// Default STT language hint.
#[must_use]
pub fn default_stt_language() -> String {
    "auto".into()
}

/// Public accessor for the private embedding-model default.
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor for the private response-cache TTL default.
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor for the private router EMA alpha default.
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor for the private router reorder-interval default.
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
214
/// Kind of LLM backend; serialized as a lowercase string in config files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    /// `"ollama"` — local Ollama server.
    Ollama,
    /// `"claude"` — Anthropic Claude API.
    Claude,
    /// `"openai"` — OpenAI API.
    OpenAi,
    /// `"gemini"` — Google Gemini API.
    Gemini,
    /// `"candle"` — in-process candle inference.
    Candle,
    /// `"compatible"` — any OpenAI-compatible endpoint.
    Compatible,
}
243
impl ProviderKind {
    /// Canonical lowercase name, matching the serde
    /// `rename_all = "lowercase"` wire representation.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Ollama => "ollama",
            Self::Claude => "claude",
            Self::OpenAi => "openai",
            Self::Gemini => "gemini",
            Self::Candle => "candle",
            Self::Compatible => "compatible",
        }
    }
}
267
268impl std::fmt::Display for ProviderKind {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.write_str(self.as_str())
271 }
272}
273
/// Top-level `[llm]` configuration section.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Provider pool (`[[llm.providers]]`); the first entry drives the
    /// `effective_*` fallbacks.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    /// Routing strategy across the pool; `None` disables routing and is
    /// omitted on serialization.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    /// Named routes — presumably route label to provider-name lists; confirm
    /// against the router implementation.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,

    /// Embedding model name; defaults to `qwen3-embedding`.
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Standalone candle backend settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Speech-to-text settings (`[llm.stt]`).
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Exact-match response cache toggle.
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// Response cache TTL in seconds (default one hour).
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Semantic (similarity-based) cache toggle.
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Similarity threshold for a semantic-cache hit.
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Max candidates considered per semantic-cache lookup.
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// EMA-based router scoring toggle.
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// EMA smoothing factor for router scoring.
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Requests between router reorder passes.
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Detailed router settings (`[llm.router]`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Optional system-instruction file applied globally.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Model used for summarization, when different from the chat model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Dedicated provider entry for summarization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    /// Complexity-triage routing settings (`[llm.complexity_routing]`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,

    /// Chain-of-evaluation shadow settings (`[llm.coe]`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub coe: Option<CoeConfig>,
}
375
/// Serde-path alias for [`default_embedding_model`], used by
/// `LlmConfig::embedding_model`.
fn default_embedding_model_opt() -> String {
    default_embedding_model()
}
379
380#[allow(clippy::trivially_copy_pass_by_ref)]
381fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
382 *s == LlmRoutingStrategy::None
383}
384
385impl LlmConfig {
386 #[must_use]
388 pub fn effective_provider(&self) -> ProviderKind {
389 self.providers
390 .first()
391 .map_or(ProviderKind::Ollama, |e| e.provider_type)
392 }
393
394 #[must_use]
396 pub fn effective_base_url(&self) -> &str {
397 self.providers
398 .first()
399 .and_then(|e| e.base_url.as_deref())
400 .unwrap_or("http://localhost:11434")
401 }
402
403 #[must_use]
405 pub fn effective_model(&self) -> &str {
406 self.providers
407 .first()
408 .and_then(|e| e.model.as_deref())
409 .unwrap_or("qwen3:8b")
410 }
411
412 #[must_use]
420 pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
421 let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
422 if name_hint.is_empty() {
423 self.providers.iter().find(|p| p.stt_model.is_some())
424 } else {
425 self.providers
426 .iter()
427 .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
428 }
429 }
430
431 pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
437 Ok(())
438 }
439
440 pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
446 use crate::error::ConfigError;
447
448 let Some(stt) = &self.stt else {
449 return Ok(());
450 };
451 if stt.provider.is_empty() {
452 return Ok(());
453 }
454 let found = self
455 .providers
456 .iter()
457 .find(|p| p.effective_name() == stt.provider);
458 match found {
459 None => {
460 return Err(ConfigError::Validation(format!(
461 "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
462 stt.provider
463 )));
464 }
465 Some(entry) if entry.stt_model.is_none() => {
466 tracing::warn!(
467 provider = stt.provider,
468 "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
469 );
470 }
471 _ => {}
472 }
473 Ok(())
474 }
475}
476
/// Speech-to-text settings (`[llm.stt]`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Provider name to use for STT; empty selects the first provider with
    /// an `stt_model` (see `LlmConfig::stt_provider_entry`).
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Language hint passed to the STT backend; `"auto"` by default.
    #[serde(default = "default_stt_language")]
    pub language: String,
}
499
/// Router scoring strategy selectable under `[llm.router]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    /// Exponential moving average scoring (default).
    #[default]
    Ema,
    /// Thompson sampling.
    Thompson,
    /// Quality-gated cascade escalation.
    Cascade,
    /// Contextual bandit.
    Bandit,
}
514
/// ASI (answer self-inspection) settings — NOTE(review): exact semantics of
/// the coherence signal live outside this file; names below follow the field
/// defaults only.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AsiConfig {
    /// Feature toggle; off by default.
    #[serde(default)]
    pub enabled: bool,

    /// Observation window size.
    #[serde(default = "default_asi_window")]
    pub window: usize,

    /// Coherence threshold below which a penalty applies.
    #[serde(default = "default_asi_coherence_threshold")]
    pub coherence_threshold: f32,

    /// Weight of the applied penalty.
    #[serde(default = "default_asi_penalty_weight")]
    pub penalty_weight: f32,
}
548
/// Default for `AsiConfig::window`.
fn default_asi_window() -> usize {
    5
}

/// Default for `AsiConfig::coherence_threshold`.
fn default_asi_coherence_threshold() -> f32 {
    0.7
}

/// Default for `AsiConfig::penalty_weight`.
fn default_asi_penalty_weight() -> f32 {
    0.3
}
560
561impl Default for AsiConfig {
562 fn default() -> Self {
563 Self {
564 enabled: false,
565 window: default_asi_window(),
566 coherence_threshold: default_asi_coherence_threshold(),
567 penalty_weight: default_asi_penalty_weight(),
568 }
569 }
570}
571
/// Detailed router settings (`[llm.router]`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Scoring strategy; EMA by default.
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Persistence path for Thompson-sampling state.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Cascade strategy tuning.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Provider reputation tracking.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// Contextual-bandit strategy tuning.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
    /// Minimum quality score gate, when set.
    #[serde(default)]
    pub quality_gate: Option<f32>,
    /// Answer self-inspection settings.
    #[serde(default)]
    pub asi: Option<AsiConfig>,
    /// Max concurrent embedding requests issued by the router.
    #[serde(default = "default_embed_concurrency")]
    pub embed_concurrency: usize,
}

/// Default for `RouterConfig::embed_concurrency`.
fn default_embed_concurrency() -> usize {
    4
}
621
/// Provider reputation tracking settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Feature toggle; off by default.
    #[serde(default)]
    pub enabled: bool,
    /// Multiplicative decay applied to past observations.
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Weight of reputation in the overall routing score.
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Observations required before reputation is trusted.
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Persistence path for reputation state.
    #[serde(default)]
    pub state_path: Option<String>,
}
652
/// Cascade routing: start cheap, escalate when answer quality falls short.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Quality score below which the cascade escalates.
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    /// Maximum number of escalation steps per request.
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    /// How answer quality is judged (heuristic or judge model).
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    /// Observation window size.
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    /// Optional cap on tokens spent across the cascade.
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    /// Optional explicit cheap-to-expensive provider ordering.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}
699
700impl Default for CascadeConfig {
701 fn default() -> Self {
702 Self {
703 quality_threshold: default_cascade_quality_threshold(),
704 max_escalations: default_cascade_max_escalations(),
705 classifier_mode: CascadeClassifierMode::default(),
706 window_size: default_cascade_window_size(),
707 max_cascade_tokens: None,
708 cost_tiers: None,
709 }
710 }
711}
712
/// How the cascade judges answer quality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    /// Rule-based scoring (default).
    #[default]
    Heuristic,
    /// Scoring by a judge model.
    Judge,
}
725
/// Default for `BanditConfig::alpha`.
fn default_bandit_alpha() -> f32 {
    1.0
}

/// Default for `BanditConfig::dim`.
fn default_bandit_dim() -> usize {
    32
}

/// Default for `BanditConfig::cost_weight`.
fn default_bandit_cost_weight() -> f32 {
    0.1
}

/// Default for `BanditConfig::decay_factor` (1.0 = no decay).
fn default_bandit_decay_factor() -> f32 {
    1.0
}

/// Default for `BanditConfig::embedding_timeout_ms`.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

/// Default for `BanditConfig::cache_size`.
fn default_bandit_cache_size() -> usize {
    512
}
749
/// Contextual-bandit routing settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// Exploration parameter.
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    /// Context/feature dimensionality.
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    /// Weight of provider cost in the reward.
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    /// Decay applied to accumulated statistics (1.0 = none).
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    /// Provider used to embed queries; empty name means unset.
    #[serde(default)]
    pub embedding_provider: ProviderName,

    /// Milliseconds to wait for a query embedding.
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    /// Embedding cache capacity (entries).
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    /// Persistence path for bandit state.
    #[serde(default)]
    pub state_path: Option<String>,

    /// Confidence above which remembered decisions are reused.
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,

    /// Optional number of warm-up queries before exploitation.
    #[serde(default)]
    pub warmup_queries: Option<u64>,
}

/// Default for `BanditConfig::memory_confidence_threshold`.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9
}
834
835impl Default for BanditConfig {
836 fn default() -> Self {
837 Self {
838 alpha: default_bandit_alpha(),
839 dim: default_bandit_dim(),
840 cost_weight: default_bandit_cost_weight(),
841 decay_factor: default_bandit_decay_factor(),
842 embedding_provider: ProviderName::default(),
843 embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
844 cache_size: default_bandit_cache_size(),
845 state_path: None,
846 memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
847 warmup_queries: None,
848 }
849 }
850}
851
/// Standalone candle backend settings (`[llm.candle]`).
/// NOTE(review): field-for-field identical to `CandleInlineConfig`; kept
/// separate presumably for serde/back-compat reasons — consider unifying.
#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Model source; `"huggingface"` by default.
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path when not fetching from a hub.
    #[serde(default)]
    pub local_path: String,
    /// Specific weights file to load, when set.
    #[serde(default)]
    pub filename: Option<String>,
    /// Prompt chat template; `"chatml"` by default.
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device; `"cpu"` by default.
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Hugging Face access token for gated models.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
    /// Inference timeout in seconds.
    #[serde(default = "default_inference_timeout_secs")]
    pub inference_timeout_secs: u64,
}

/// Default for `inference_timeout_secs` (two minutes).
fn default_inference_timeout_secs() -> u64 {
    120
}
888
/// Sampling parameters for candle text generation.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Nucleus sampling cutoff, when set.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff, when set.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Maximum tokens to generate; capped by `MAX_TOKENS_CAP` at use time.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed for reproducible sampling.
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty factor.
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// How many recent tokens the repetition penalty considers.
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
919
920pub const MAX_TOKENS_CAP: usize = 32768;
922
923impl GenerationParams {
924 #[must_use]
935 pub fn capped_max_tokens(&self) -> usize {
936 self.max_tokens.min(MAX_TOKENS_CAP)
937 }
938}
939
940impl Default for GenerationParams {
941 fn default() -> Self {
942 Self {
943 temperature: default_temperature(),
944 top_p: None,
945 top_k: None,
946 max_tokens: default_max_tokens(),
947 seed: default_seed(),
948 repeat_penalty: default_repeat_penalty(),
949 repeat_last_n: default_repeat_last_n(),
950 }
951 }
952}
953
/// Top-level routing strategy for `[llm].routing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// Routing disabled (default); omitted on serialization.
    #[default]
    None,
    /// Exponential-moving-average scoring.
    Ema,
    /// Thompson sampling.
    Thompson,
    /// Quality-gated cascade.
    Cascade,
    /// Per-task routing.
    Task,
    /// Complexity triage (see `ComplexityRoutingConfig`).
    Triage,
    /// Contextual bandit.
    Bandit,
}
976
/// Default for `ComplexityRoutingConfig::triage_timeout_secs`.
fn default_triage_timeout_secs() -> u64 {
    5
}

/// Default for `ComplexityRoutingConfig::max_triage_tokens`.
fn default_max_triage_tokens() -> u32 {
    50
}

/// Serde helper: fields whose default is `true`.
fn default_true() -> bool {
    true
}
988
/// Complexity tier to provider-name mapping; unset tiers fall through to the
/// triage fallback behavior. Missing TOML keys deserialize to `None` via
/// serde's built-in `Option` handling.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    // One optional provider name per complexity tier.
    pub simple: Option<String>,
    pub medium: Option<String>,
    pub complex: Option<String>,
    pub expert: Option<String>,
}
997
/// Complexity-triage routing (`[llm.complexity_routing]`): a small triage
/// call classifies the query, then a tier-mapped provider answers it.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider used for the triage call; unset uses — TODO confirm — some
    /// default provider.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Skip triage entirely when only one provider is configured (default).
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Complexity tier to provider mapping.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Token cap for the triage call.
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Timeout for the triage call in seconds.
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Strategy name to fall back to when triage fails — TODO confirm
    /// accepted values against the router.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
1046
1047impl Default for ComplexityRoutingConfig {
1048 fn default() -> Self {
1049 Self {
1050 triage_provider: None,
1051 bypass_single_provider: true,
1052 tiers: TierMapping::default(),
1053 max_triage_tokens: default_max_triage_tokens(),
1054 triage_timeout_secs: default_triage_timeout_secs(),
1055 fallback_strategy: None,
1056 }
1057 }
1058}
1059
/// Chain-of-evaluation shadow settings (`[llm.coe]`); whole table takes
/// struct-level serde defaults.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct CoeConfig {
    /// Feature toggle; off by default.
    pub enabled: bool,
    /// Intra-response threshold.
    pub intra_threshold: f64,
    /// Inter-response threshold.
    pub inter_threshold: f64,
    /// Fraction of requests shadow-evaluated.
    pub shadow_sample_rate: f64,
    /// Secondary provider used for shadow responses.
    pub secondary_provider: ProviderName,
    /// Provider used for embeddings.
    pub embed_provider: ProviderName,
}
1093
1094impl Default for CoeConfig {
1095 fn default() -> Self {
1096 Self {
1097 enabled: false,
1098 intra_threshold: 0.8,
1099 inter_threshold: 0.20,
1100 shadow_sample_rate: 0.1,
1101 secondary_provider: ProviderName::default(),
1102 embed_provider: ProviderName::default(),
1103 }
1104 }
1105}
1106
/// Candle backend settings inlined inside a `[[llm.providers]]` entry.
/// NOTE(review): field-for-field identical to `CandleConfig`; kept separate
/// presumably for serde/back-compat reasons — consider unifying.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Model source; `"huggingface"` by default.
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path when not fetching from a hub.
    #[serde(default)]
    pub local_path: String,
    /// Specific weights file to load, when set.
    #[serde(default)]
    pub filename: Option<String>,
    /// Prompt chat template; `"chatml"` by default.
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device; `"cpu"` by default.
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Hugging Face access token for gated models.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
    /// Inference timeout in seconds.
    #[serde(default = "default_inference_timeout_secs")]
    pub inference_timeout_secs: u64,
}
1135
1136impl Default for CandleInlineConfig {
1137 fn default() -> Self {
1138 Self {
1139 source: default_candle_source(),
1140 local_path: String::new(),
1141 filename: None,
1142 chat_template: default_chat_template(),
1143 device: default_candle_device(),
1144 embedding_repo: None,
1145 hf_token: None,
1146 generation: GenerationParams::default(),
1147 inference_timeout_secs: default_inference_timeout_secs(),
1148 }
1149 }
1150}
1151
/// One `[[llm.providers]]` pool entry.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Backend kind; spelled `type` in config.
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    /// Routing name; falls back to the kind's canonical string
    /// (see `effective_name`). Required for `type = "compatible"`.
    #[serde(default)]
    pub name: Option<String>,

    /// Chat model; falls back to a per-kind default (see `effective_model`).
    #[serde(default)]
    pub model: Option<String>,

    /// Endpoint base URL.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Per-provider response token cap.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Embedding model served by this provider.
    #[serde(default)]
    pub embedding_model: Option<String>,

    /// STT model; setting it marks the entry as STT-capable
    /// (see `LlmConfig::stt_provider_entry`). Warned against on Ollama.
    #[serde(default)]
    pub stt_model: Option<String>,

    /// Use this entry for embeddings.
    #[serde(default)]
    pub embed: bool,

    /// Mark as the pool default; at most one entry may set this
    /// (enforced by `validate_pool`).
    #[serde(default)]
    pub default: bool,

    /// Claude-only extended thinking; other kinds log a warning.
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    /// Server-side conversation compaction toggle.
    #[serde(default)]
    pub server_compaction: bool,
    /// Extended context window toggle.
    #[serde(default)]
    pub enable_extended_context: bool,
    /// Prompt cache TTL.
    #[serde(default)]
    pub prompt_cache_ttl: Option<CacheTtl>,

    /// OpenAI-only reasoning effort; other kinds log a warning.
    #[serde(default)]
    pub reasoning_effort: Option<String>,

    /// Gemini-only thinking level; other kinds log a warning.
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Gemini-only thinking token budget; other kinds log a warning.
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// Whether to surface model thoughts in responses.
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    /// API key; NOTE(review): stored in plain config — presumably can also
    /// come from the environment, confirm before documenting further.
    #[serde(default)]
    pub api_key: Option<String>,

    /// Inline candle backend settings for `type = "candle"`.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    /// Vision model served by this provider.
    #[serde(default)]
    pub vision_model: Option<String>,

    /// Per-provider system-instruction file.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
1237
1238impl Default for ProviderEntry {
1239 fn default() -> Self {
1240 Self {
1241 provider_type: ProviderKind::Ollama,
1242 name: None,
1243 model: None,
1244 base_url: None,
1245 max_tokens: None,
1246 embedding_model: None,
1247 stt_model: None,
1248 embed: false,
1249 default: false,
1250 thinking: None,
1251 server_compaction: false,
1252 enable_extended_context: false,
1253 prompt_cache_ttl: None,
1254 reasoning_effort: None,
1255 thinking_level: None,
1256 thinking_budget: None,
1257 include_thoughts: None,
1258 api_key: None,
1259 candle: None,
1260 vision_model: None,
1261 instruction_file: None,
1262 }
1263 }
1264}
1265
1266impl ProviderEntry {
1267 #[must_use]
1269 pub fn effective_name(&self) -> String {
1270 self.name
1271 .clone()
1272 .unwrap_or_else(|| self.provider_type.as_str().to_owned())
1273 }
1274
1275 #[must_use]
1280 pub fn effective_model(&self) -> String {
1281 if let Some(ref m) = self.model {
1282 return m.clone();
1283 }
1284 match self.provider_type {
1285 ProviderKind::Ollama => "qwen3:8b".to_owned(),
1286 ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1287 ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1288 ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1289 ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1290 }
1291 }
1292
1293 pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1300 use crate::error::ConfigError;
1301
1302 if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1304 return Err(ConfigError::Validation(
1305 "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1306 ));
1307 }
1308
1309 match self.provider_type {
1311 ProviderKind::Ollama => {
1312 if self.thinking.is_some() {
1313 tracing::warn!(
1314 provider = self.effective_name(),
1315 "field `thinking` is only used by Claude providers"
1316 );
1317 }
1318 if self.reasoning_effort.is_some() {
1319 tracing::warn!(
1320 provider = self.effective_name(),
1321 "field `reasoning_effort` is only used by OpenAI providers"
1322 );
1323 }
1324 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1325 tracing::warn!(
1326 provider = self.effective_name(),
1327 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1328 );
1329 }
1330 }
1331 ProviderKind::Claude => {
1332 if self.reasoning_effort.is_some() {
1333 tracing::warn!(
1334 provider = self.effective_name(),
1335 "field `reasoning_effort` is only used by OpenAI providers"
1336 );
1337 }
1338 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1339 tracing::warn!(
1340 provider = self.effective_name(),
1341 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1342 );
1343 }
1344 }
1345 ProviderKind::OpenAi => {
1346 if self.thinking.is_some() {
1347 tracing::warn!(
1348 provider = self.effective_name(),
1349 "field `thinking` is only used by Claude providers"
1350 );
1351 }
1352 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1353 tracing::warn!(
1354 provider = self.effective_name(),
1355 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1356 );
1357 }
1358 }
1359 ProviderKind::Gemini => {
1360 if self.thinking.is_some() {
1361 tracing::warn!(
1362 provider = self.effective_name(),
1363 "field `thinking` is only used by Claude providers"
1364 );
1365 }
1366 if self.reasoning_effort.is_some() {
1367 tracing::warn!(
1368 provider = self.effective_name(),
1369 "field `reasoning_effort` is only used by OpenAI providers"
1370 );
1371 }
1372 }
1373 _ => {}
1374 }
1375
1376 if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1379 tracing::warn!(
1380 provider = self.effective_name(),
1381 "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1382 Whisper STT API — use OpenAI, compatible, or candle instead"
1383 );
1384 }
1385
1386 Ok(())
1387 }
1388}
1389
1390pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1400 use crate::error::ConfigError;
1401 use std::collections::HashSet;
1402
1403 if entries.is_empty() {
1404 return Err(ConfigError::Validation(
1405 "at least one LLM provider must be configured in [[llm.providers]]".into(),
1406 ));
1407 }
1408
1409 let default_count = entries.iter().filter(|e| e.default).count();
1410 if default_count > 1 {
1411 return Err(ConfigError::Validation(
1412 "only one [[llm.providers]] entry can be marked `default = true`".into(),
1413 ));
1414 }
1415
1416 let mut seen_names: HashSet<String> = HashSet::new();
1417 for entry in entries {
1418 let name = entry.effective_name();
1419 if !seen_names.insert(name.clone()) {
1420 return Err(ConfigError::Validation(format!(
1421 "duplicate provider name \"{name}\" in [[llm.providers]]"
1422 )));
1423 }
1424 entry.validate()?;
1425 }
1426
1427 Ok(())
1428}
1429
1430#[cfg(test)]
1431mod tests {
1432 use super::*;
1433
    // Fixtures shared across the pool/validation tests below.

    /// Minimal valid Ollama entry.
    fn ollama_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Ollama,
            name: Some("ollama".into()),
            model: Some("qwen3:8b".into()),
            ..Default::default()
        }
    }

    /// Minimal valid Claude entry with an explicit token cap.
    fn claude_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Claude,
            name: Some("claude".into()),
            model: Some("claude-sonnet-4-6".into()),
            max_tokens: Some(8192),
            ..Default::default()
        }
    }

    #[test]
    fn validate_ollama_valid() {
        assert!(ollama_entry().validate().is_ok());
    }

    #[test]
    fn validate_claude_valid() {
        assert!(claude_entry().validate().is_ok());
    }

    // `type = "compatible"` is the only kind with a hard `name` requirement.
    #[test]
    fn validate_compatible_without_name_errors() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: None,
            ..Default::default()
        };
        let err = entry.validate().unwrap_err();
        assert!(
            err.to_string().contains("compatible"),
            "error should mention compatible: {err}"
        );
    }

    #[test]
    fn validate_compatible_with_name_ok() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: Some("my-proxy".into()),
            base_url: Some("http://localhost:8080".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_openai_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            name: Some("openai".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_gemini_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            name: Some("gemini".into()),
            model: Some("gemini-2.0-flash".into()),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }
1514
    // Pool-level invariants enforced by `validate_pool`.

    #[test]
    fn validate_pool_empty_errors() {
        let err = validate_pool(&[]).unwrap_err();
        assert!(err.to_string().contains("at least one"), "{err}");
    }

    #[test]
    fn validate_pool_single_entry_ok() {
        assert!(validate_pool(&[ollama_entry()]).is_ok());
    }
1527
1528 #[test]
1529 fn validate_pool_duplicate_names_errors() {
1530 let a = ollama_entry();
1531 let b = ollama_entry(); let err = validate_pool(&[a, b]).unwrap_err();
1533 assert!(err.to_string().contains("duplicate"), "{err}");
1534 }
1535
    // At most one entry may be marked `default = true`.
    #[test]
    fn validate_pool_multiple_defaults_errors() {
        let mut a = ollama_entry();
        let mut b = claude_entry();
        a.default = true;
        b.default = true;
        let err = validate_pool(&[a, b]).unwrap_err();
        assert!(err.to_string().contains("default"), "{err}");
    }

    #[test]
    fn validate_pool_two_different_providers_ok() {
        assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
    }
1550
1551 #[test]
1552 fn validate_pool_propagates_entry_error() {
1553 let bad = ProviderEntry {
1554 provider_type: ProviderKind::Compatible,
1555 name: None, ..Default::default()
1557 };
1558 assert!(validate_pool(&[bad]).is_err());
1559 }
1560
    // Per-kind model fallbacks of `ProviderEntry::effective_model`.

    #[test]
    fn effective_model_returns_explicit_when_set() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: Some("claude-sonnet-4-6".into()),
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
    }

    #[test]
    fn effective_model_ollama_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Ollama,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_model_claude_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
    }

    #[test]
    fn effective_model_openai_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gpt-4o-mini");
    }

    #[test]
    fn effective_model_gemini_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gemini-2.0-flash");
    }

    /// Parses a TOML snippet containing an `[llm]` table into `LlmConfig`.
    fn parse_llm(toml: &str) -> LlmConfig {
        #[derive(serde::Deserialize)]
        struct Wrapper {
            llm: LlmConfig,
        }
        toml::from_str::<Wrapper>(toml).unwrap().llm
    }
1623
    // Legacy-format check and `LlmConfig::effective_*` fallbacks.

    #[test]
    fn check_legacy_format_new_format_ok() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn check_legacy_format_empty_providers_no_legacy_ok() {
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn effective_provider_falls_back_to_ollama_when_no_providers() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
    }

    #[test]
    fn effective_provider_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "claude"
model = "claude-sonnet-4-6"
"#,
        );
        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
    }

    #[test]
    fn effective_model_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert_eq!(cfg.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_base_url_default_when_absent() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
    }

    #[test]
    fn effective_base_url_from_providers_entry() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
base_url = "http://myhost:11434"
"#,
        );
        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
    }
1700
    // Complexity-triage routing configuration.

    #[test]
    fn complexity_routing_defaults() {
        let cr = ComplexityRoutingConfig::default();
        assert!(
            cr.bypass_single_provider,
            "bypass_single_provider must default to true"
        );
        assert_eq!(cr.triage_timeout_secs, 5);
        assert_eq!(cr.max_triage_tokens, 50);
        assert!(cr.triage_provider.is_none());
        assert!(cr.tiers.simple.is_none());
    }

    #[test]
    fn complexity_routing_toml_round_trip() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing]
triage_provider = "fast"
bypass_single_provider = false
triage_timeout_secs = 10
max_triage_tokens = 100

[llm.complexity_routing.tiers]
simple = "fast"
medium = "medium"
complex = "large"
expert = "opus"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
        assert!(!cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 10);
        assert_eq!(cr.max_triage_tokens, 100);
        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
    }

    // Missing tier keys deserialize to None; other fields keep defaults.
    #[test]
    fn complexity_routing_partial_tiers_toml() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing.tiers]
simple = "haiku"
complex = "sonnet"
"#,
        );
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
        assert!(cr.tiers.medium.is_none());
        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
        assert!(cr.tiers.expert.is_none());
        assert!(cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 5);
    }

    #[test]
    fn routing_strategy_triage_deserialized() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
    }
1785
1786 #[test]
1789 fn stt_provider_entry_by_name_match() {
1790 let cfg = parse_llm(
1791 r#"
1792[llm]
1793
1794[[llm.providers]]
1795type = "openai"
1796name = "quality"
1797model = "gpt-5.4"
1798stt_model = "gpt-4o-mini-transcribe"
1799
1800[llm.stt]
1801provider = "quality"
1802"#,
1803 );
1804 let entry = cfg.stt_provider_entry().expect("should find stt provider");
1805 assert_eq!(entry.effective_name(), "quality");
1806 assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
1807 }
1808
1809 #[test]
1810 fn stt_provider_entry_auto_detect_when_provider_empty() {
1811 let cfg = parse_llm(
1812 r#"
1813[llm]
1814
1815[[llm.providers]]
1816type = "openai"
1817name = "openai-stt"
1818stt_model = "whisper-1"
1819
1820[llm.stt]
1821provider = ""
1822"#,
1823 );
1824 let entry = cfg.stt_provider_entry().expect("should auto-detect");
1825 assert_eq!(entry.effective_name(), "openai-stt");
1826 }
1827
1828 #[test]
1829 fn stt_provider_entry_auto_detect_no_stt_section() {
1830 let cfg = parse_llm(
1831 r#"
1832[llm]
1833
1834[[llm.providers]]
1835type = "openai"
1836name = "openai-stt"
1837stt_model = "whisper-1"
1838"#,
1839 );
1840 let entry = cfg.stt_provider_entry().expect("should auto-detect");
1842 assert_eq!(entry.effective_name(), "openai-stt");
1843 }
1844
1845 #[test]
1846 fn stt_provider_entry_none_when_no_stt_model() {
1847 let cfg = parse_llm(
1848 r#"
1849[llm]
1850
1851[[llm.providers]]
1852type = "openai"
1853name = "quality"
1854model = "gpt-5.4"
1855"#,
1856 );
1857 assert!(cfg.stt_provider_entry().is_none());
1858 }
1859
1860 #[test]
1861 fn stt_provider_entry_name_mismatch_falls_back_to_none() {
1862 let cfg = parse_llm(
1864 r#"
1865[llm]
1866
1867[[llm.providers]]
1868type = "openai"
1869name = "quality"
1870model = "gpt-5.4"
1871
1872[[llm.providers]]
1873type = "openai"
1874name = "openai-stt"
1875stt_model = "whisper-1"
1876
1877[llm.stt]
1878provider = "quality"
1879"#,
1880 );
1881 assert!(cfg.stt_provider_entry().is_none());
1883 }
1884
1885 #[test]
1886 fn stt_config_deserializes_new_slim_format() {
1887 let cfg = parse_llm(
1888 r#"
1889[llm]
1890
1891[[llm.providers]]
1892type = "openai"
1893name = "quality"
1894stt_model = "whisper-1"
1895
1896[llm.stt]
1897provider = "quality"
1898language = "en"
1899"#,
1900 );
1901 let stt = cfg.stt.as_ref().expect("stt section present");
1902 assert_eq!(stt.provider, "quality");
1903 assert_eq!(stt.language, "en");
1904 }
1905
1906 #[test]
1907 fn stt_config_default_provider_is_empty() {
1908 assert_eq!(default_stt_provider(), "");
1910 }
1911
1912 #[test]
1913 fn validate_stt_missing_provider_ok() {
1914 let cfg = parse_llm("[llm]\n");
1915 assert!(cfg.validate_stt().is_ok());
1916 }
1917
1918 #[test]
1919 fn validate_stt_valid_reference() {
1920 let cfg = parse_llm(
1921 r#"
1922[llm]
1923
1924[[llm.providers]]
1925type = "openai"
1926name = "quality"
1927stt_model = "whisper-1"
1928
1929[llm.stt]
1930provider = "quality"
1931"#,
1932 );
1933 assert!(cfg.validate_stt().is_ok());
1934 }
1935
1936 #[test]
1937 fn validate_stt_nonexistent_provider_errors() {
1938 let cfg = parse_llm(
1939 r#"
1940[llm]
1941
1942[[llm.providers]]
1943type = "openai"
1944name = "quality"
1945model = "gpt-5.4"
1946
1947[llm.stt]
1948provider = "nonexistent"
1949"#,
1950 );
1951 assert!(cfg.validate_stt().is_err());
1952 }
1953
1954 #[test]
1955 fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1956 let cfg = parse_llm(
1958 r#"
1959[llm]
1960
1961[[llm.providers]]
1962type = "openai"
1963name = "quality"
1964model = "gpt-5.4"
1965
1966[llm.stt]
1967provider = "quality"
1968"#,
1969 );
1970 assert!(cfg.validate_stt().is_ok());
1972 assert!(
1974 cfg.stt_provider_entry().is_none(),
1975 "stt_provider_entry must be None when provider has no stt_model"
1976 );
1977 }
1978
1979 #[test]
1982 fn bandit_warmup_queries_explicit_value_is_deserialized() {
1983 let cfg = parse_llm(
1984 r#"
1985[llm]
1986
1987[llm.router]
1988strategy = "bandit"
1989
1990[llm.router.bandit]
1991warmup_queries = 50
1992"#,
1993 );
1994 let bandit = cfg
1995 .router
1996 .expect("router section must be present")
1997 .bandit
1998 .expect("bandit section must be present");
1999 assert_eq!(
2000 bandit.warmup_queries,
2001 Some(50),
2002 "warmup_queries = 50 must deserialize to Some(50)"
2003 );
2004 }
2005
2006 #[test]
2007 fn bandit_warmup_queries_explicit_null_is_none() {
2008 let cfg = parse_llm(
2011 r#"
2012[llm]
2013
2014[llm.router]
2015strategy = "bandit"
2016
2017[llm.router.bandit]
2018warmup_queries = 0
2019"#,
2020 );
2021 let bandit = cfg
2022 .router
2023 .expect("router section must be present")
2024 .bandit
2025 .expect("bandit section must be present");
2026 assert_eq!(
2028 bandit.warmup_queries,
2029 Some(0),
2030 "warmup_queries = 0 must deserialize to Some(0)"
2031 );
2032 }
2033
2034 #[test]
2035 fn bandit_warmup_queries_missing_field_defaults_to_none() {
2036 let cfg = parse_llm(
2038 r#"
2039[llm]
2040
2041[llm.router]
2042strategy = "bandit"
2043
2044[llm.router.bandit]
2045alpha = 1.5
2046"#,
2047 );
2048 let bandit = cfg
2049 .router
2050 .expect("router section must be present")
2051 .bandit
2052 .expect("bandit section must be present");
2053 assert_eq!(
2054 bandit.warmup_queries, None,
2055 "omitted warmup_queries must default to None"
2056 );
2057 }
2058
2059 #[test]
2060 fn provider_name_new_and_as_str() {
2061 let n = ProviderName::new("fast");
2062 assert_eq!(n.as_str(), "fast");
2063 assert!(!n.is_empty());
2064 }
2065
2066 #[test]
2067 fn provider_name_default_is_empty() {
2068 let n = ProviderName::default();
2069 assert!(n.is_empty());
2070 assert_eq!(n.as_str(), "");
2071 }
2072
2073 #[test]
2074 fn provider_name_deref_to_str() {
2075 let n = ProviderName::new("quality");
2076 let s: &str = &n;
2077 assert_eq!(s, "quality");
2078 }
2079
2080 #[test]
2081 fn provider_name_partial_eq_str() {
2082 let n = ProviderName::new("fast");
2083 assert_eq!(n, "fast");
2084 assert_ne!(n, "slow");
2085 }
2086
2087 #[test]
2088 fn provider_name_serde_roundtrip() {
2089 let n = ProviderName::new("my-provider");
2090 let json = serde_json::to_string(&n).expect("serialize");
2091 assert_eq!(json, "\"my-provider\"");
2092 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2093 assert_eq!(back, n);
2094 }
2095
2096 #[test]
2097 fn provider_name_serde_empty_roundtrip() {
2098 let n = ProviderName::default();
2099 let json = serde_json::to_string(&n).expect("serialize");
2100 assert_eq!(json, "\"\"");
2101 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2102 assert_eq!(back, n);
2103 assert!(back.is_empty());
2104 }
2105}