1use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};
8
/// Logical name of an LLM provider entry (the `name` used for routing,
/// STT lookup, and `[[llm.providers]]` cross-references).
///
/// `#[serde(transparent)]` makes it (de)serialize as a plain string.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProviderName(String);
16
17impl ProviderName {
18 #[must_use]
19 pub fn new(name: impl Into<String>) -> Self {
20 Self(name.into())
21 }
22
23 #[must_use]
24 pub fn is_empty(&self) -> bool {
25 self.0.is_empty()
26 }
27
28 #[must_use]
29 pub fn as_str(&self) -> &str {
30 &self.0
31 }
32}
33
34impl fmt::Display for ProviderName {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 self.0.fmt(f)
37 }
38}
39
40impl AsRef<str> for ProviderName {
41 fn as_ref(&self) -> &str {
42 &self.0
43 }
44}
45
46impl std::ops::Deref for ProviderName {
47 type Target = str;
48
49 fn deref(&self) -> &str {
50 &self.0
51 }
52}
53
54impl PartialEq<str> for ProviderName {
55 fn eq(&self, other: &str) -> bool {
56 self.0 == other
57 }
58}
59
60impl PartialEq<&str> for ProviderName {
61 fn eq(&self, other: &&str) -> bool {
62 self.0 == *other
63 }
64}
65
/// Default TTL for the response cache: one hour.
fn default_response_cache_ttl_secs() -> u64 {
    3600
}

/// Default cosine-similarity threshold for semantic cache hits.
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Default number of candidates examined per semantic cache lookup.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

/// Default EMA smoothing factor for router statistics.
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Default number of requests between router reorder passes.
fn default_router_reorder_interval() -> u64 {
    10
}
85
/// Default embedding model name.
fn default_embedding_model() -> String {
    String::from("qwen3-embedding")
}

/// Default model source for the Candle backend.
fn default_candle_source() -> String {
    String::from("huggingface")
}

/// Default chat prompt template for the Candle backend.
fn default_chat_template() -> String {
    String::from("chatml")
}

/// Default compute device for the Candle backend.
fn default_candle_device() -> String {
    String::from("cpu")
}
101
/// Default sampling temperature.
fn default_temperature() -> f64 {
    0.7
}

/// Default generation budget in tokens.
fn default_max_tokens() -> usize {
    2048
}

/// Default RNG seed for reproducible sampling.
fn default_seed() -> u64 {
    42
}

/// Default repetition penalty multiplier.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Default lookback window (in tokens) for the repetition penalty.
fn default_repeat_last_n() -> usize {
    64
}
121
/// Default quality score below which a cascade answer escalates.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Default maximum number of cascade escalations per request.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Default sliding-window size for cascade quality tracking.
fn default_cascade_window_size() -> usize {
    50
}

/// Default exponential decay applied to reputation observations.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Default weight of the reputation term in routing scores.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Default observation count before reputation is trusted.
fn default_reputation_min_observations() -> u64 {
    5
}
145
/// Default STT provider name. The empty string means "auto-detect": pick
/// the first provider entry that declares an `stt_model`.
#[must_use]
pub fn default_stt_provider() -> String {
    String::default()
}

/// Default STT language hint.
#[must_use]
pub fn default_stt_language() -> String {
    String::from("auto")
}
155
/// Public accessor for the private serde default; exposes the embedding
/// model default to code outside this module.
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor for the response-cache TTL default (seconds).
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor for the router EMA smoothing-factor default.
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor for the router reorder-interval default.
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
175
/// Kind of backend a provider entry talks to.
///
/// Deserialized from lowercase strings (`"ollama"`, `"claude"`, …),
/// matching the output of [`ProviderKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    /// Ollama backend (default base URL `http://localhost:11434`).
    Ollama,
    /// Anthropic Claude backend.
    Claude,
    /// OpenAI backend.
    OpenAi,
    /// Google Gemini backend.
    Gemini,
    /// Local Candle inference backend.
    Candle,
    /// OpenAI-compatible endpoint; requires an explicit `name`
    /// (enforced in `ProviderEntry::validate`).
    Compatible,
}
187
188impl ProviderKind {
189 #[must_use]
190 pub fn as_str(self) -> &'static str {
191 match self {
192 Self::Ollama => "ollama",
193 Self::Claude => "claude",
194 Self::OpenAi => "openai",
195 Self::Gemini => "gemini",
196 Self::Candle => "candle",
197 Self::Compatible => "compatible",
198 }
199 }
200}
201
202impl std::fmt::Display for ProviderKind {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.write_str(self.as_str())
205 }
206}
207
/// Top-level `[llm]` configuration section.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Provider pool, declared as `[[llm.providers]]` entries.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    /// Strategy used to pick providers; the `None` strategy is omitted on
    /// serialization.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    /// Named routes. Presumably maps a route name to an ordered list of
    /// provider names — TODO confirm against the router implementation.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,

    /// Embedding model name (default `"qwen3-embedding"`).
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Top-level Candle backend settings, if configured.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Speech-to-text settings (`[llm.stt]`).
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Enables the response cache (off by default).
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// Response cache TTL in seconds (default 3600).
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Enables the semantic (similarity-based) cache (off by default).
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Similarity threshold for semantic cache hits (default 0.95).
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Max candidates per semantic cache lookup (default 10).
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// Enables EMA-based router scoring (off by default).
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// EMA smoothing factor (default 0.1).
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Requests between router reorder passes (default 10).
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Extended router settings (`[llm.router]`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Optional path to a global instruction file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Model override used for summarization, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Dedicated provider entry for summarization, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    /// Complexity/triage routing settings (`[llm.complexity_routing]`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,
}
282
/// Serde `default = …` shim for `LlmConfig::embedding_model`; delegates to
/// the shared embedding-model default.
fn default_embedding_model_opt() -> String {
    default_embedding_model()
}

/// Serde `skip_serializing_if` predicate: true when routing is the inert
/// `None` strategy, so the field is omitted from serialized output.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
    *s == LlmRoutingStrategy::None
}
291
292impl LlmConfig {
293 #[must_use]
295 pub fn effective_provider(&self) -> ProviderKind {
296 self.providers
297 .first()
298 .map_or(ProviderKind::Ollama, |e| e.provider_type)
299 }
300
301 #[must_use]
303 pub fn effective_base_url(&self) -> &str {
304 self.providers
305 .first()
306 .and_then(|e| e.base_url.as_deref())
307 .unwrap_or("http://localhost:11434")
308 }
309
310 #[must_use]
312 pub fn effective_model(&self) -> &str {
313 self.providers
314 .first()
315 .and_then(|e| e.model.as_deref())
316 .unwrap_or("qwen3:8b")
317 }
318
319 #[must_use]
327 pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
328 let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
329 if name_hint.is_empty() {
330 self.providers.iter().find(|p| p.stt_model.is_some())
331 } else {
332 self.providers
333 .iter()
334 .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
335 }
336 }
337
338 pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
344 Ok(())
345 }
346
347 pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
353 use crate::error::ConfigError;
354
355 let Some(stt) = &self.stt else {
356 return Ok(());
357 };
358 if stt.provider.is_empty() {
359 return Ok(());
360 }
361 let found = self
362 .providers
363 .iter()
364 .find(|p| p.effective_name() == stt.provider);
365 match found {
366 None => {
367 return Err(ConfigError::Validation(format!(
368 "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
369 stt.provider
370 )));
371 }
372 Some(entry) if entry.stt_model.is_none() => {
373 tracing::warn!(
374 provider = stt.provider,
375 "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
376 );
377 }
378 _ => {}
379 }
380 Ok(())
381 }
382}
383
/// `[llm.stt]` — speech-to-text settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Name of the provider entry to use for STT. The empty string (the
    /// default) means auto-detect the first entry with an `stt_model`.
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Language hint passed to the STT backend (default `"auto"`).
    #[serde(default = "default_stt_language")]
    pub language: String,
}
394
/// Strategy selector for the extended router (`[llm.router].strategy`).
///
/// Deserialized from lowercase strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    /// Exponential-moving-average scoring (the default).
    #[default]
    Ema,
    /// Thompson-sampling selection.
    Thompson,
    /// Cheapest-first cascade with escalation (see `CascadeConfig`).
    Cascade,
    /// Contextual bandit (see `BanditConfig`).
    Bandit,
}
409
/// `[llm.router.asi]` — coherence-based penalty settings for the router.
/// NOTE(review): exact ASI semantics live in the router implementation;
/// field descriptions below are inferred from names — confirm there.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AsiConfig {
    /// Whether the ASI check is applied (off by default).
    #[serde(default)]
    pub enabled: bool,

    /// Sliding-window size used by the check (default 5).
    #[serde(default = "default_asi_window")]
    pub window: usize,

    /// Coherence threshold (default 0.7).
    #[serde(default = "default_asi_coherence_threshold")]
    pub coherence_threshold: f32,

    /// Weight of the applied penalty (default 0.3).
    #[serde(default = "default_asi_penalty_weight")]
    pub penalty_weight: f32,
}
443
/// Default ASI sliding-window size.
fn default_asi_window() -> usize {
    5
}

/// Default ASI coherence threshold.
fn default_asi_coherence_threshold() -> f32 {
    0.7
}

/// Default ASI penalty weight.
fn default_asi_penalty_weight() -> f32 {
    0.3
}
455
456impl Default for AsiConfig {
457 fn default() -> Self {
458 Self {
459 enabled: false,
460 window: default_asi_window(),
461 coherence_threshold: default_asi_coherence_threshold(),
462 penalty_weight: default_asi_penalty_weight(),
463 }
464 }
465}
466
/// `[llm.router]` — extended router configuration.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Selection strategy (default: EMA).
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Optional path for persisting Thompson-sampling state.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Cascade strategy settings, when `strategy = "cascade"`.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Provider reputation tracking settings.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// Contextual-bandit settings, when `strategy = "bandit"`.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
    /// Optional quality gate threshold. NOTE(review): gate semantics are
    /// defined by the router implementation — confirm there.
    #[serde(default)]
    pub quality_gate: Option<f32>,
    /// ASI (coherence penalty) settings.
    #[serde(default)]
    pub asi: Option<AsiConfig>,
    /// Max concurrent embedding requests issued by the router (default 4).
    #[serde(default = "default_embed_concurrency")]
    pub embed_concurrency: usize,
}
512
/// Default concurrency limit for router-issued embedding requests.
fn default_embed_concurrency() -> usize {
    4
}
516
/// `[llm.router.reputation]` — provider reputation tracking.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Whether reputation tracking is active (off by default).
    #[serde(default)]
    pub enabled: bool,
    /// Exponential decay applied to past observations (default 0.95).
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Weight of the reputation term in routing decisions (default 0.3).
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Observations required before reputation is used (default 5).
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Optional path for persisting reputation state.
    #[serde(default)]
    pub state_path: Option<String>,
}
547
/// `[llm.router.cascade]` — cheapest-first cascade settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Quality score below which an answer escalates (default 0.5).
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    /// Maximum escalations per request (default 2).
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    /// How answer quality is classified (default: heuristic).
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    /// Sliding-window size for quality tracking (default 50).
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    /// Optional hard token budget across the whole cascade.
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    /// Optional explicit provider ordering by cost tier.
    /// NOTE(review): element semantics defined by the router — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}
594
595impl Default for CascadeConfig {
596 fn default() -> Self {
597 Self {
598 quality_threshold: default_cascade_quality_threshold(),
599 max_escalations: default_cascade_max_escalations(),
600 classifier_mode: CascadeClassifierMode::default(),
601 window_size: default_cascade_window_size(),
602 max_cascade_tokens: None,
603 cost_tiers: None,
604 }
605 }
606}
607
/// How cascade answer quality is judged before escalating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    /// Built-in heuristic scoring (the default).
    #[default]
    Heuristic,
    /// Judge-based scoring (implemented by the router).
    Judge,
}
620
/// Default bandit exploration parameter.
fn default_bandit_alpha() -> f32 {
    1.0
}

/// Default context-feature dimensionality.
fn default_bandit_dim() -> usize {
    32
}

/// Default weight of the cost term in the bandit reward.
fn default_bandit_cost_weight() -> f32 {
    0.1
}

/// Default decay factor (1.0 = no decay).
fn default_bandit_decay_factor() -> f32 {
    1.0
}

/// Default embedding-request timeout in milliseconds.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

/// Default embedding cache capacity (entries).
fn default_bandit_cache_size() -> usize {
    512
}
644
/// `[llm.router.bandit]` — contextual-bandit routing settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// Exploration parameter (default 1.0).
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    /// Context-feature dimensionality (default 32).
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    /// Weight of the cost term in the reward (default 0.1).
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    /// Observation decay factor; 1.0 (the default) disables decay.
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    /// Provider used to embed queries for context features; an empty name
    /// (the `ProviderName` default) leaves the choice to the router.
    #[serde(default)]
    pub embedding_provider: ProviderName,

    /// Timeout for embedding requests, in milliseconds (default 50).
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    /// Embedding cache capacity in entries (default 512).
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    /// Optional path for persisting bandit state.
    #[serde(default)]
    pub state_path: Option<String>,

    /// Confidence threshold for memory-based shortcuts (default 0.9).
    /// NOTE(review): exact semantics defined by the router — confirm.
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,

    /// Optional number of warmup queries before the policy is trusted.
    #[serde(default)]
    pub warmup_queries: Option<u64>,
}
725
/// Default confidence threshold for bandit memory shortcuts.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9
}
729
730impl Default for BanditConfig {
731 fn default() -> Self {
732 Self {
733 alpha: default_bandit_alpha(),
734 dim: default_bandit_dim(),
735 cost_weight: default_bandit_cost_weight(),
736 decay_factor: default_bandit_decay_factor(),
737 embedding_provider: ProviderName::default(),
738 embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
739 cache_size: default_bandit_cache_size(),
740 state_path: None,
741 memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
742 warmup_queries: None,
743 }
744 }
745}
746
/// `[llm.candle]` — settings for the local Candle inference backend.
#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Model source (default `"huggingface"`).
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local filesystem path to the model (empty when unused).
    #[serde(default)]
    pub local_path: String,
    /// Optional specific model file name within the source.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template (default `"chatml"`).
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default `"cpu"`).
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face access token for gated models.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
}
769
/// Sampling and generation parameters for the Candle backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature (default 0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Optional nucleus-sampling cutoff.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Optional top-k sampling cutoff.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Generation budget in tokens (default 2048); see
    /// `capped_max_tokens` for the enforced ceiling.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed (default 42).
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty multiplier (default 1.1).
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Lookback window in tokens for the repetition penalty (default 64).
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
787
788pub const MAX_TOKENS_CAP: usize = 32768;
789
790impl GenerationParams {
791 #[must_use]
792 pub fn capped_max_tokens(&self) -> usize {
793 self.max_tokens.min(MAX_TOKENS_CAP)
794 }
795}
796
797impl Default for GenerationParams {
798 fn default() -> Self {
799 Self {
800 temperature: default_temperature(),
801 top_p: None,
802 top_k: None,
803 max_tokens: default_max_tokens(),
804 seed: default_seed(),
805 repeat_penalty: default_repeat_penalty(),
806 repeat_last_n: default_repeat_last_n(),
807 }
808 }
809}
810
/// Top-level routing strategy selector (`[llm].routing`).
///
/// Deserialized from lowercase strings (e.g. `routing = "triage"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// No routing strategy (the default); omitted when serializing.
    #[default]
    None,
    /// EMA-based routing.
    Ema,
    /// Thompson-sampling routing.
    Thompson,
    /// Cascade routing.
    Cascade,
    /// Task-based routing.
    Task,
    /// Complexity triage routing (configured by `ComplexityRoutingConfig`).
    Triage,
    /// Contextual-bandit routing.
    Bandit,
}
833
/// Default timeout for the triage call, in seconds.
fn default_triage_timeout_secs() -> u64 {
    5
}

/// Default token budget for the triage call.
fn default_max_triage_tokens() -> u32 {
    50
}

/// Serde default helper for fields that default to `true`.
fn default_true() -> bool {
    true
}
845
/// Provider name per complexity tier; unset tiers fall through to the
/// router's fallback behavior.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    /// Provider for simple queries.
    pub simple: Option<String>,
    /// Provider for medium-complexity queries.
    pub medium: Option<String>,
    /// Provider for complex queries.
    pub complex: Option<String>,
    /// Provider for expert-level queries.
    pub expert: Option<String>,
}
854
/// `[llm.complexity_routing]` — triage-based complexity routing.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider that performs the triage call; unset leaves the choice to
    /// the router.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Skip triage entirely when only one provider is configured
    /// (default true).
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Provider mapping per complexity tier.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Token budget for the triage call (default 50).
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Triage call timeout in seconds (default 5).
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Optional strategy name used when triage fails.
    /// NOTE(review): accepted values defined by the router — confirm.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
903
904impl Default for ComplexityRoutingConfig {
905 fn default() -> Self {
906 Self {
907 triage_provider: None,
908 bypass_single_provider: true,
909 tiers: TierMapping::default(),
910 max_triage_tokens: default_max_triage_tokens(),
911 triage_timeout_secs: default_triage_timeout_secs(),
912 fallback_strategy: None,
913 }
914 }
915}
916
/// Per-provider Candle settings embedded in a `[[llm.providers]]` entry.
///
/// Mirrors [`CandleConfig`] field-for-field, but is `Clone` + `Default` so
/// it can live inside `ProviderEntry`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Model source (default `"huggingface"`).
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local filesystem path to the model (empty when unused).
    #[serde(default)]
    pub local_path: String,
    /// Optional specific model file name within the source.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template (default `"chatml"`).
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default `"cpu"`).
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face access token for gated models.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
}
939
940impl Default for CandleInlineConfig {
941 fn default() -> Self {
942 Self {
943 source: default_candle_source(),
944 local_path: String::new(),
945 filename: None,
946 chat_template: default_chat_template(),
947 device: default_candle_device(),
948 embedding_repo: None,
949 hf_token: None,
950 generation: GenerationParams::default(),
951 }
952 }
953}
954
/// One `[[llm.providers]]` entry describing a single backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Backend kind; serialized as `type = "..."`.
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    /// Explicit entry name; falls back to the kind's canonical name
    /// (see `effective_name`). Required for `type = "compatible"`.
    #[serde(default)]
    pub name: Option<String>,

    /// Model override; falls back per kind (see `effective_model`).
    #[serde(default)]
    pub model: Option<String>,

    /// Endpoint base URL override.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Optional per-request token limit.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Embedding model served by this provider, if any.
    #[serde(default)]
    pub embedding_model: Option<String>,

    /// STT model served by this provider; presence marks the entry as an
    /// STT candidate (see `LlmConfig::stt_provider_entry`). Not supported
    /// on Ollama entries (warned in `validate`).
    #[serde(default)]
    pub stt_model: Option<String>,

    /// Marks this entry as usable for embeddings.
    #[serde(default)]
    pub embed: bool,

    /// Marks this entry as the pool default; at most one entry may set it
    /// (enforced by `validate_pool`).
    #[serde(default)]
    pub default: bool,

    /// Claude-only extended-thinking settings; warned if set elsewhere.
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    /// Enables server-side context compaction.
    #[serde(default)]
    pub server_compaction: bool,
    /// Enables extended context windows where the backend supports them.
    #[serde(default)]
    pub enable_extended_context: bool,

    /// OpenAI-only reasoning-effort setting; warned if set elsewhere.
    #[serde(default)]
    pub reasoning_effort: Option<String>,

    /// Gemini-only thinking level; warned if set elsewhere.
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Gemini-only thinking token budget; warned if set elsewhere.
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// Whether Gemini responses include thought output.
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    /// API key for this backend, if required.
    #[serde(default)]
    pub api_key: Option<String>,

    /// Inline Candle settings for `type = "candle"` entries.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    /// Vision-capable model served by this provider, if any.
    #[serde(default)]
    pub vision_model: Option<String>,

    /// Per-provider instruction file override.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
1036
1037impl Default for ProviderEntry {
1038 fn default() -> Self {
1039 Self {
1040 provider_type: ProviderKind::Ollama,
1041 name: None,
1042 model: None,
1043 base_url: None,
1044 max_tokens: None,
1045 embedding_model: None,
1046 stt_model: None,
1047 embed: false,
1048 default: false,
1049 thinking: None,
1050 server_compaction: false,
1051 enable_extended_context: false,
1052 reasoning_effort: None,
1053 thinking_level: None,
1054 thinking_budget: None,
1055 include_thoughts: None,
1056 api_key: None,
1057 candle: None,
1058 vision_model: None,
1059 instruction_file: None,
1060 }
1061 }
1062}
1063
1064impl ProviderEntry {
1065 #[must_use]
1067 pub fn effective_name(&self) -> String {
1068 self.name
1069 .clone()
1070 .unwrap_or_else(|| self.provider_type.as_str().to_owned())
1071 }
1072
1073 #[must_use]
1078 pub fn effective_model(&self) -> String {
1079 if let Some(ref m) = self.model {
1080 return m.clone();
1081 }
1082 match self.provider_type {
1083 ProviderKind::Ollama => "qwen3:8b".to_owned(),
1084 ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1085 ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1086 ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1087 ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1088 }
1089 }
1090
1091 pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1098 use crate::error::ConfigError;
1099
1100 if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1102 return Err(ConfigError::Validation(
1103 "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1104 ));
1105 }
1106
1107 match self.provider_type {
1109 ProviderKind::Ollama => {
1110 if self.thinking.is_some() {
1111 tracing::warn!(
1112 provider = self.effective_name(),
1113 "field `thinking` is only used by Claude providers"
1114 );
1115 }
1116 if self.reasoning_effort.is_some() {
1117 tracing::warn!(
1118 provider = self.effective_name(),
1119 "field `reasoning_effort` is only used by OpenAI providers"
1120 );
1121 }
1122 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1123 tracing::warn!(
1124 provider = self.effective_name(),
1125 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1126 );
1127 }
1128 }
1129 ProviderKind::Claude => {
1130 if self.reasoning_effort.is_some() {
1131 tracing::warn!(
1132 provider = self.effective_name(),
1133 "field `reasoning_effort` is only used by OpenAI providers"
1134 );
1135 }
1136 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1137 tracing::warn!(
1138 provider = self.effective_name(),
1139 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1140 );
1141 }
1142 }
1143 ProviderKind::OpenAi => {
1144 if self.thinking.is_some() {
1145 tracing::warn!(
1146 provider = self.effective_name(),
1147 "field `thinking` is only used by Claude providers"
1148 );
1149 }
1150 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1151 tracing::warn!(
1152 provider = self.effective_name(),
1153 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1154 );
1155 }
1156 }
1157 ProviderKind::Gemini => {
1158 if self.thinking.is_some() {
1159 tracing::warn!(
1160 provider = self.effective_name(),
1161 "field `thinking` is only used by Claude providers"
1162 );
1163 }
1164 if self.reasoning_effort.is_some() {
1165 tracing::warn!(
1166 provider = self.effective_name(),
1167 "field `reasoning_effort` is only used by OpenAI providers"
1168 );
1169 }
1170 }
1171 _ => {}
1172 }
1173
1174 if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1177 tracing::warn!(
1178 provider = self.effective_name(),
1179 "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1180 Whisper STT API — use OpenAI, compatible, or candle instead"
1181 );
1182 }
1183
1184 Ok(())
1185 }
1186}
1187
1188pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1198 use crate::error::ConfigError;
1199 use std::collections::HashSet;
1200
1201 if entries.is_empty() {
1202 return Err(ConfigError::Validation(
1203 "at least one LLM provider must be configured in [[llm.providers]]".into(),
1204 ));
1205 }
1206
1207 let default_count = entries.iter().filter(|e| e.default).count();
1208 if default_count > 1 {
1209 return Err(ConfigError::Validation(
1210 "only one [[llm.providers]] entry can be marked `default = true`".into(),
1211 ));
1212 }
1213
1214 let mut seen_names: HashSet<String> = HashSet::new();
1215 for entry in entries {
1216 let name = entry.effective_name();
1217 if !seen_names.insert(name.clone()) {
1218 return Err(ConfigError::Validation(format!(
1219 "duplicate provider name \"{name}\" in [[llm.providers]]"
1220 )));
1221 }
1222 entry.validate()?;
1223 }
1224
1225 Ok(())
1226}
1227
1228#[cfg(test)]
1229mod tests {
1230 use super::*;
1231
1232 fn ollama_entry() -> ProviderEntry {
1233 ProviderEntry {
1234 provider_type: ProviderKind::Ollama,
1235 name: Some("ollama".into()),
1236 model: Some("qwen3:8b".into()),
1237 ..Default::default()
1238 }
1239 }
1240
1241 fn claude_entry() -> ProviderEntry {
1242 ProviderEntry {
1243 provider_type: ProviderKind::Claude,
1244 name: Some("claude".into()),
1245 model: Some("claude-sonnet-4-6".into()),
1246 max_tokens: Some(8192),
1247 ..Default::default()
1248 }
1249 }
1250
1251 #[test]
1254 fn validate_ollama_valid() {
1255 assert!(ollama_entry().validate().is_ok());
1256 }
1257
1258 #[test]
1259 fn validate_claude_valid() {
1260 assert!(claude_entry().validate().is_ok());
1261 }
1262
1263 #[test]
1264 fn validate_compatible_without_name_errors() {
1265 let entry = ProviderEntry {
1266 provider_type: ProviderKind::Compatible,
1267 name: None,
1268 ..Default::default()
1269 };
1270 let err = entry.validate().unwrap_err();
1271 assert!(
1272 err.to_string().contains("compatible"),
1273 "error should mention compatible: {err}"
1274 );
1275 }
1276
1277 #[test]
1278 fn validate_compatible_with_name_ok() {
1279 let entry = ProviderEntry {
1280 provider_type: ProviderKind::Compatible,
1281 name: Some("my-proxy".into()),
1282 base_url: Some("http://localhost:8080".into()),
1283 model: Some("gpt-4o".into()),
1284 max_tokens: Some(4096),
1285 ..Default::default()
1286 };
1287 assert!(entry.validate().is_ok());
1288 }
1289
1290 #[test]
1291 fn validate_openai_valid() {
1292 let entry = ProviderEntry {
1293 provider_type: ProviderKind::OpenAi,
1294 name: Some("openai".into()),
1295 model: Some("gpt-4o".into()),
1296 max_tokens: Some(4096),
1297 ..Default::default()
1298 };
1299 assert!(entry.validate().is_ok());
1300 }
1301
1302 #[test]
1303 fn validate_gemini_valid() {
1304 let entry = ProviderEntry {
1305 provider_type: ProviderKind::Gemini,
1306 name: Some("gemini".into()),
1307 model: Some("gemini-2.0-flash".into()),
1308 ..Default::default()
1309 };
1310 assert!(entry.validate().is_ok());
1311 }
1312
1313 #[test]
1316 fn validate_pool_empty_errors() {
1317 let err = validate_pool(&[]).unwrap_err();
1318 assert!(err.to_string().contains("at least one"), "{err}");
1319 }
1320
1321 #[test]
1322 fn validate_pool_single_entry_ok() {
1323 assert!(validate_pool(&[ollama_entry()]).is_ok());
1324 }
1325
1326 #[test]
1327 fn validate_pool_duplicate_names_errors() {
1328 let a = ollama_entry();
1329 let b = ollama_entry(); let err = validate_pool(&[a, b]).unwrap_err();
1331 assert!(err.to_string().contains("duplicate"), "{err}");
1332 }
1333
1334 #[test]
1335 fn validate_pool_multiple_defaults_errors() {
1336 let mut a = ollama_entry();
1337 let mut b = claude_entry();
1338 a.default = true;
1339 b.default = true;
1340 let err = validate_pool(&[a, b]).unwrap_err();
1341 assert!(err.to_string().contains("default"), "{err}");
1342 }
1343
1344 #[test]
1345 fn validate_pool_two_different_providers_ok() {
1346 assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
1347 }
1348
1349 #[test]
1350 fn validate_pool_propagates_entry_error() {
1351 let bad = ProviderEntry {
1352 provider_type: ProviderKind::Compatible,
1353 name: None, ..Default::default()
1355 };
1356 assert!(validate_pool(&[bad]).is_err());
1357 }
1358
1359 #[test]
1362 fn effective_model_returns_explicit_when_set() {
1363 let entry = ProviderEntry {
1364 provider_type: ProviderKind::Claude,
1365 model: Some("claude-sonnet-4-6".into()),
1366 ..Default::default()
1367 };
1368 assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
1369 }
1370
1371 #[test]
1372 fn effective_model_ollama_default_when_none() {
1373 let entry = ProviderEntry {
1374 provider_type: ProviderKind::Ollama,
1375 model: None,
1376 ..Default::default()
1377 };
1378 assert_eq!(entry.effective_model(), "qwen3:8b");
1379 }
1380
1381 #[test]
1382 fn effective_model_claude_default_when_none() {
1383 let entry = ProviderEntry {
1384 provider_type: ProviderKind::Claude,
1385 model: None,
1386 ..Default::default()
1387 };
1388 assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
1389 }
1390
1391 #[test]
1392 fn effective_model_openai_default_when_none() {
1393 let entry = ProviderEntry {
1394 provider_type: ProviderKind::OpenAi,
1395 model: None,
1396 ..Default::default()
1397 };
1398 assert_eq!(entry.effective_model(), "gpt-4o-mini");
1399 }
1400
1401 #[test]
1402 fn effective_model_gemini_default_when_none() {
1403 let entry = ProviderEntry {
1404 provider_type: ProviderKind::Gemini,
1405 model: None,
1406 ..Default::default()
1407 };
1408 assert_eq!(entry.effective_model(), "gemini-2.0-flash");
1409 }
1410
1411 fn parse_llm(toml: &str) -> LlmConfig {
1415 #[derive(serde::Deserialize)]
1416 struct Wrapper {
1417 llm: LlmConfig,
1418 }
1419 toml::from_str::<Wrapper>(toml).unwrap().llm
1420 }
1421
1422 #[test]
1423 fn check_legacy_format_new_format_ok() {
1424 let cfg = parse_llm(
1425 r#"
1426[llm]
1427
1428[[llm.providers]]
1429type = "ollama"
1430model = "qwen3:8b"
1431"#,
1432 );
1433 assert!(cfg.check_legacy_format().is_ok());
1434 }
1435
1436 #[test]
1437 fn check_legacy_format_empty_providers_no_legacy_ok() {
1438 let cfg = parse_llm("[llm]\n");
1440 assert!(cfg.check_legacy_format().is_ok());
1441 }
1442
1443 #[test]
1446 fn effective_provider_falls_back_to_ollama_when_no_providers() {
1447 let cfg = parse_llm("[llm]\n");
1448 assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
1449 }
1450
1451 #[test]
1452 fn effective_provider_reads_from_providers_first() {
1453 let cfg = parse_llm(
1454 r#"
1455[llm]
1456
1457[[llm.providers]]
1458type = "claude"
1459model = "claude-sonnet-4-6"
1460"#,
1461 );
1462 assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
1463 }
1464
1465 #[test]
1466 fn effective_model_reads_from_providers_first() {
1467 let cfg = parse_llm(
1468 r#"
1469[llm]
1470
1471[[llm.providers]]
1472type = "ollama"
1473model = "qwen3:8b"
1474"#,
1475 );
1476 assert_eq!(cfg.effective_model(), "qwen3:8b");
1477 }
1478
1479 #[test]
1480 fn effective_base_url_default_when_absent() {
1481 let cfg = parse_llm("[llm]\n");
1482 assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
1483 }
1484
1485 #[test]
1486 fn effective_base_url_from_providers_entry() {
1487 let cfg = parse_llm(
1488 r#"
1489[llm]
1490
1491[[llm.providers]]
1492type = "ollama"
1493base_url = "http://myhost:11434"
1494"#,
1495 );
1496 assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
1497 }
1498
1499 #[test]
1502 fn complexity_routing_defaults() {
1503 let cr = ComplexityRoutingConfig::default();
1504 assert!(
1505 cr.bypass_single_provider,
1506 "bypass_single_provider must default to true"
1507 );
1508 assert_eq!(cr.triage_timeout_secs, 5);
1509 assert_eq!(cr.max_triage_tokens, 50);
1510 assert!(cr.triage_provider.is_none());
1511 assert!(cr.tiers.simple.is_none());
1512 }
1513
1514 #[test]
1515 fn complexity_routing_toml_round_trip() {
1516 let cfg = parse_llm(
1517 r#"
1518[llm]
1519routing = "triage"
1520
1521[llm.complexity_routing]
1522triage_provider = "fast"
1523bypass_single_provider = false
1524triage_timeout_secs = 10
1525max_triage_tokens = 100
1526
1527[llm.complexity_routing.tiers]
1528simple = "fast"
1529medium = "medium"
1530complex = "large"
1531expert = "opus"
1532"#,
1533 );
1534 assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1535 let cr = cfg
1536 .complexity_routing
1537 .expect("complexity_routing must be present");
1538 assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
1539 assert!(!cr.bypass_single_provider);
1540 assert_eq!(cr.triage_timeout_secs, 10);
1541 assert_eq!(cr.max_triage_tokens, 100);
1542 assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
1543 assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
1544 assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
1545 assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
1546 }
1547
1548 #[test]
1549 fn complexity_routing_partial_tiers_toml() {
1550 let cfg = parse_llm(
1552 r#"
1553[llm]
1554routing = "triage"
1555
1556[llm.complexity_routing.tiers]
1557simple = "haiku"
1558complex = "sonnet"
1559"#,
1560 );
1561 let cr = cfg
1562 .complexity_routing
1563 .expect("complexity_routing must be present");
1564 assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
1565 assert!(cr.tiers.medium.is_none());
1566 assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
1567 assert!(cr.tiers.expert.is_none());
1568 assert!(cr.bypass_single_provider);
1570 assert_eq!(cr.triage_timeout_secs, 5);
1571 }
1572
1573 #[test]
1574 fn routing_strategy_triage_deserialized() {
1575 let cfg = parse_llm(
1576 r#"
1577[llm]
1578routing = "triage"
1579"#,
1580 );
1581 assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1582 }
1583
1584 #[test]
1587 fn stt_provider_entry_by_name_match() {
1588 let cfg = parse_llm(
1589 r#"
1590[llm]
1591
1592[[llm.providers]]
1593type = "openai"
1594name = "quality"
1595model = "gpt-5.4"
1596stt_model = "gpt-4o-mini-transcribe"
1597
1598[llm.stt]
1599provider = "quality"
1600"#,
1601 );
1602 let entry = cfg.stt_provider_entry().expect("should find stt provider");
1603 assert_eq!(entry.effective_name(), "quality");
1604 assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
1605 }
1606
1607 #[test]
1608 fn stt_provider_entry_auto_detect_when_provider_empty() {
1609 let cfg = parse_llm(
1610 r#"
1611[llm]
1612
1613[[llm.providers]]
1614type = "openai"
1615name = "openai-stt"
1616stt_model = "whisper-1"
1617
1618[llm.stt]
1619provider = ""
1620"#,
1621 );
1622 let entry = cfg.stt_provider_entry().expect("should auto-detect");
1623 assert_eq!(entry.effective_name(), "openai-stt");
1624 }
1625
1626 #[test]
1627 fn stt_provider_entry_auto_detect_no_stt_section() {
1628 let cfg = parse_llm(
1629 r#"
1630[llm]
1631
1632[[llm.providers]]
1633type = "openai"
1634name = "openai-stt"
1635stt_model = "whisper-1"
1636"#,
1637 );
1638 let entry = cfg.stt_provider_entry().expect("should auto-detect");
1640 assert_eq!(entry.effective_name(), "openai-stt");
1641 }
1642
1643 #[test]
1644 fn stt_provider_entry_none_when_no_stt_model() {
1645 let cfg = parse_llm(
1646 r#"
1647[llm]
1648
1649[[llm.providers]]
1650type = "openai"
1651name = "quality"
1652model = "gpt-5.4"
1653"#,
1654 );
1655 assert!(cfg.stt_provider_entry().is_none());
1656 }
1657
1658 #[test]
1659 fn stt_provider_entry_name_mismatch_falls_back_to_none() {
1660 let cfg = parse_llm(
1662 r#"
1663[llm]
1664
1665[[llm.providers]]
1666type = "openai"
1667name = "quality"
1668model = "gpt-5.4"
1669
1670[[llm.providers]]
1671type = "openai"
1672name = "openai-stt"
1673stt_model = "whisper-1"
1674
1675[llm.stt]
1676provider = "quality"
1677"#,
1678 );
1679 assert!(cfg.stt_provider_entry().is_none());
1681 }
1682
1683 #[test]
1684 fn stt_config_deserializes_new_slim_format() {
1685 let cfg = parse_llm(
1686 r#"
1687[llm]
1688
1689[[llm.providers]]
1690type = "openai"
1691name = "quality"
1692stt_model = "whisper-1"
1693
1694[llm.stt]
1695provider = "quality"
1696language = "en"
1697"#,
1698 );
1699 let stt = cfg.stt.as_ref().expect("stt section present");
1700 assert_eq!(stt.provider, "quality");
1701 assert_eq!(stt.language, "en");
1702 }
1703
1704 #[test]
1705 fn stt_config_default_provider_is_empty() {
1706 assert_eq!(default_stt_provider(), "");
1708 }
1709
1710 #[test]
1711 fn validate_stt_missing_provider_ok() {
1712 let cfg = parse_llm("[llm]\n");
1713 assert!(cfg.validate_stt().is_ok());
1714 }
1715
1716 #[test]
1717 fn validate_stt_valid_reference() {
1718 let cfg = parse_llm(
1719 r#"
1720[llm]
1721
1722[[llm.providers]]
1723type = "openai"
1724name = "quality"
1725stt_model = "whisper-1"
1726
1727[llm.stt]
1728provider = "quality"
1729"#,
1730 );
1731 assert!(cfg.validate_stt().is_ok());
1732 }
1733
1734 #[test]
1735 fn validate_stt_nonexistent_provider_errors() {
1736 let cfg = parse_llm(
1737 r#"
1738[llm]
1739
1740[[llm.providers]]
1741type = "openai"
1742name = "quality"
1743model = "gpt-5.4"
1744
1745[llm.stt]
1746provider = "nonexistent"
1747"#,
1748 );
1749 assert!(cfg.validate_stt().is_err());
1750 }
1751
1752 #[test]
1753 fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1754 let cfg = parse_llm(
1756 r#"
1757[llm]
1758
1759[[llm.providers]]
1760type = "openai"
1761name = "quality"
1762model = "gpt-5.4"
1763
1764[llm.stt]
1765provider = "quality"
1766"#,
1767 );
1768 assert!(cfg.validate_stt().is_ok());
1770 assert!(
1772 cfg.stt_provider_entry().is_none(),
1773 "stt_provider_entry must be None when provider has no stt_model"
1774 );
1775 }
1776
1777 #[test]
1780 fn bandit_warmup_queries_explicit_value_is_deserialized() {
1781 let cfg = parse_llm(
1782 r#"
1783[llm]
1784
1785[llm.router]
1786strategy = "bandit"
1787
1788[llm.router.bandit]
1789warmup_queries = 50
1790"#,
1791 );
1792 let bandit = cfg
1793 .router
1794 .expect("router section must be present")
1795 .bandit
1796 .expect("bandit section must be present");
1797 assert_eq!(
1798 bandit.warmup_queries,
1799 Some(50),
1800 "warmup_queries = 50 must deserialize to Some(50)"
1801 );
1802 }
1803
1804 #[test]
1805 fn bandit_warmup_queries_explicit_null_is_none() {
1806 let cfg = parse_llm(
1809 r#"
1810[llm]
1811
1812[llm.router]
1813strategy = "bandit"
1814
1815[llm.router.bandit]
1816warmup_queries = 0
1817"#,
1818 );
1819 let bandit = cfg
1820 .router
1821 .expect("router section must be present")
1822 .bandit
1823 .expect("bandit section must be present");
1824 assert_eq!(
1826 bandit.warmup_queries,
1827 Some(0),
1828 "warmup_queries = 0 must deserialize to Some(0)"
1829 );
1830 }
1831
1832 #[test]
1833 fn bandit_warmup_queries_missing_field_defaults_to_none() {
1834 let cfg = parse_llm(
1836 r#"
1837[llm]
1838
1839[llm.router]
1840strategy = "bandit"
1841
1842[llm.router.bandit]
1843alpha = 1.5
1844"#,
1845 );
1846 let bandit = cfg
1847 .router
1848 .expect("router section must be present")
1849 .bandit
1850 .expect("bandit section must be present");
1851 assert_eq!(
1852 bandit.warmup_queries, None,
1853 "omitted warmup_queries must default to None"
1854 );
1855 }
1856
1857 #[test]
1858 fn provider_name_new_and_as_str() {
1859 let n = ProviderName::new("fast");
1860 assert_eq!(n.as_str(), "fast");
1861 assert!(!n.is_empty());
1862 }
1863
1864 #[test]
1865 fn provider_name_default_is_empty() {
1866 let n = ProviderName::default();
1867 assert!(n.is_empty());
1868 assert_eq!(n.as_str(), "");
1869 }
1870
1871 #[test]
1872 fn provider_name_deref_to_str() {
1873 let n = ProviderName::new("quality");
1874 let s: &str = &n;
1875 assert_eq!(s, "quality");
1876 }
1877
1878 #[test]
1879 fn provider_name_partial_eq_str() {
1880 let n = ProviderName::new("fast");
1881 assert_eq!(n, "fast");
1882 assert_ne!(n, "slow");
1883 }
1884
1885 #[test]
1886 fn provider_name_serde_roundtrip() {
1887 let n = ProviderName::new("my-provider");
1888 let json = serde_json::to_string(&n).expect("serialize");
1889 assert_eq!(json, "\"my-provider\"");
1890 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
1891 assert_eq!(back, n);
1892 }
1893
1894 #[test]
1895 fn provider_name_serde_empty_roundtrip() {
1896 let n = ProviderName::default();
1897 let json = serde_json::to_string(&n).expect("serialize");
1898 assert_eq!(json, "\"\"");
1899 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
1900 assert_eq!(back, n);
1901 assert!(back.is_empty());
1902 }
1903}