use serde::{Deserialize, Serialize};
use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};
// Serde default helpers. Each exists as a free function so it can be named in
// a `#[serde(default = "...")]` attribute on the config structs below.

/// Response-cache TTL default: 3600 s (one hour).
fn default_response_cache_ttl_secs() -> u64 {
    3600
}

/// Semantic-cache similarity threshold default (cosine-style score; 0.95).
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Semantic-cache candidate-set size default.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

/// Router EMA smoothing factor default.
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Router reorder interval default (unit not stated here — presumably
/// requests between reorders; confirm against the router implementation).
fn default_router_reorder_interval() -> u64 {
    10
}
/// Default embedding model id.
fn default_embedding_model() -> String {
    String::from("qwen3-embedding")
}

/// Default model source for the Candle backend.
fn default_candle_source() -> String {
    String::from("huggingface")
}

/// Default chat prompt template name.
fn default_chat_template() -> String {
    String::from("chatml")
}

/// Default Candle compute device.
fn default_candle_device() -> String {
    String::from("cpu")
}
// Serde defaults for `GenerationParams` (Candle sampling knobs).

/// Sampling temperature default.
fn default_temperature() -> f64 {
    0.7
}

/// Generation token budget default (capped by `MAX_TOKENS_CAP` at use sites).
fn default_max_tokens() -> usize {
    2048
}

/// RNG seed default for reproducible sampling.
fn default_seed() -> u64 {
    42
}

/// Repetition penalty default.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Number of trailing tokens the repeat penalty looks at.
fn default_repeat_last_n() -> usize {
    64
}

// Serde defaults for `CascadeConfig`.

/// Quality score below which the cascade escalates.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Maximum number of cascade escalations per request.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Cascade sliding-window size.
fn default_cascade_window_size() -> usize {
    50
}

// Serde defaults for `ReputationConfig`.

/// Per-step reputation decay factor.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Weight of reputation in routing decisions.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Minimum observations before reputation is trusted.
fn default_reputation_min_observations() -> u64 {
    5
}
/// Serde default for `SttConfig::provider`: the empty string, which
/// `LlmConfig::stt_provider_entry` treats as "auto-detect the first provider
/// that has an `stt_model`".
#[must_use]
pub fn default_stt_provider() -> String {
    String::new()
}

/// Serde default for `SttConfig::language`: the literal `"auto"`.
#[must_use]
pub fn default_stt_language() -> String {
    "auto".into()
}

// Public wrappers exposing the private serde defaults above to other modules.

/// Public accessor for the default embedding model id.
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor for the default response-cache TTL (seconds).
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor for the default router EMA alpha.
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor for the default router reorder interval.
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
/// Backend kind of a `[[llm.providers]]` entry.
///
/// Serialized in lowercase, e.g. `type = "openai"` (see `as_str` for the
/// canonical strings).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    /// Ollama server (default base URL `http://localhost:11434`).
    Ollama,
    /// Anthropic Claude.
    Claude,
    /// OpenAI.
    OpenAi,
    /// Google Gemini.
    Gemini,
    /// Local Candle inference backend (configured via `CandleConfig`).
    Candle,
    /// Generic OpenAI-compatible endpoint; entries of this kind must set
    /// `name` (enforced by `ProviderEntry::validate`).
    Compatible,
}
impl ProviderKind {
    /// Canonical lowercase name for this kind, matching the
    /// `#[serde(rename_all = "lowercase")]` wire encoding.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Ollama => "ollama",
            Self::Claude => "claude",
            Self::OpenAi => "openai",
            Self::Gemini => "gemini",
            Self::Candle => "candle",
            Self::Compatible => "compatible",
        }
    }
}
impl std::fmt::Display for ProviderKind {
    /// Formats as the canonical lowercase name (same as [`ProviderKind::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Top-level `[llm]` configuration section.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Provider pool from `[[llm.providers]]`; the first entry backs the
    /// `effective_*` accessors.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,
    /// Routing strategy; the default `None` is omitted on serialize.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,
    /// Named routes, each mapping to an ordered list of provider names.
    // NOTE(review): value semantics inferred from the type — confirm against
    // the router that consumes this map.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,
    /// Embedding model id (default "qwen3-embedding").
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Optional `[llm.candle]` backend settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Optional `[llm.stt]` speech-to-text settings.
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Response-cache toggle (off by default).
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// Response-cache TTL in seconds (default 3600).
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Semantic-cache toggle (off by default).
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Semantic-cache similarity threshold (default 0.95).
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Semantic-cache candidate count (default 10).
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// Router EMA toggle (off by default).
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// Router EMA smoothing factor (default 0.1).
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Router reorder interval (default 10).
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Optional `[llm.router]` strategy settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Optional path to a system-instruction file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Optional dedicated model for summarization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Optional dedicated provider entry for summarization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,
    /// Optional `[llm.complexity_routing]` (triage) settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,
}
/// Serde default for `LlmConfig::embedding_model`.
// NOTE(review): pure delegation — `LlmConfig` could reference
// `default_embedding_model` directly; kept for compatibility.
fn default_embedding_model_opt() -> String {
    default_embedding_model()
}
/// `skip_serializing_if` predicate: true when `routing` is the default
/// `None` strategy, so the field is omitted from serialized output.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_routing_none(strategy: &LlmRoutingStrategy) -> bool {
    matches!(strategy, LlmRoutingStrategy::None)
}
impl LlmConfig {
    /// Provider kind of the first pool entry; `Ollama` when the pool is empty.
    #[must_use]
    pub fn effective_provider(&self) -> ProviderKind {
        match self.providers.first() {
            Some(entry) => entry.provider_type,
            None => ProviderKind::Ollama,
        }
    }

    /// Base URL of the first pool entry, or the local Ollama default.
    #[must_use]
    pub fn effective_base_url(&self) -> &str {
        if let Some(url) = self.providers.first().and_then(|e| e.base_url.as_deref()) {
            url
        } else {
            "http://localhost:11434"
        }
    }

    /// Model of the first pool entry, or the Ollama default model.
    #[must_use]
    pub fn effective_model(&self) -> &str {
        if let Some(model) = self.providers.first().and_then(|e| e.model.as_deref()) {
            model
        } else {
            "qwen3:8b"
        }
    }

    /// Resolves the provider entry to use for STT.
    ///
    /// With an explicit, non-empty `[llm.stt].provider` the entry must match
    /// by name; otherwise the first entry with an `stt_model` is picked.
    /// Either way the chosen entry must have an `stt_model`.
    #[must_use]
    pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
        let hint = match &self.stt {
            Some(stt) => stt.provider.as_str(),
            None => "",
        };
        self.providers.iter().find(|entry| {
            entry.stt_model.is_some() && (hint.is_empty() || entry.effective_name() == hint)
        })
    }

    /// Legacy-format check; the legacy format is no longer detected, so this
    /// always succeeds. Kept for call-site compatibility.
    pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
        Ok(())
    }

    /// Validates the `[llm.stt]` section: a non-empty `provider` must name an
    /// existing pool entry (error), and that entry should carry an
    /// `stt_model` (warning only).
    pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
        use crate::error::ConfigError;
        let stt = match &self.stt {
            Some(s) => s,
            None => return Ok(()),
        };
        // Empty provider means auto-detect; nothing to validate.
        if stt.provider.is_empty() {
            return Ok(());
        }
        let matched = self
            .providers
            .iter()
            .find(|p| p.effective_name() == stt.provider);
        let Some(entry) = matched else {
            return Err(ConfigError::Validation(format!(
                "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
                stt.provider
            )));
        };
        if entry.stt_model.is_none() {
            tracing::warn!(
                provider = stt.provider,
                "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
            );
        }
        Ok(())
    }
}
/// `[llm.stt]` speech-to-text settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Name of the provider entry to use; empty (the default) means
    /// auto-detect the first entry with an `stt_model`.
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Transcription language; defaults to "auto".
    #[serde(default = "default_stt_language")]
    pub language: String,
}
/// Strategy selector for `[llm.router]` (lowercase on the wire).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    /// Exponential-moving-average ranking (the default).
    #[default]
    Ema,
    /// Thompson sampling.
    Thompson,
    /// Cascade with escalation (see `CascadeConfig`).
    Cascade,
    /// Contextual bandit (see `BanditConfig`).
    Bandit,
}
/// `[llm.router]` settings: strategy plus per-strategy sub-configs.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Active routing strategy (defaults to EMA).
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Optional path for persisting Thompson-sampling state.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Cascade-strategy settings.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Provider-reputation settings.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// Bandit-strategy settings.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
}
/// Provider-reputation tracking settings under `[llm.router.reputation]`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Master toggle (off by default).
    #[serde(default)]
    pub enabled: bool,
    /// Decay applied to past observations (default 0.95).
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Weight of reputation in routing decisions (default 0.3).
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Observations required before reputation is used (default 5).
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Optional path for persisting reputation state.
    #[serde(default)]
    pub state_path: Option<String>,
}
/// Cascade-routing settings under `[llm.router.cascade]`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Quality score below which the answer escalates (default 0.5).
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,
    /// Maximum escalations per request (default 2).
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,
    /// How answer quality is judged (heuristic by default).
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,
    /// Sliding-window size (default 50).
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,
    /// Optional hard token cap for cascaded requests.
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,
    /// Optional ordered list of cost tiers.
    // NOTE(review): tier semantics inferred from the name — confirm against
    // the cascade router.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}
impl Default for CascadeConfig {
    /// Mirrors the `#[serde(default = ...)]` attributes so programmatic
    /// construction matches deserializing an empty `[llm.router.cascade]`.
    fn default() -> Self {
        Self {
            quality_threshold: default_cascade_quality_threshold(),
            max_escalations: default_cascade_max_escalations(),
            classifier_mode: CascadeClassifierMode::default(),
            window_size: default_cascade_window_size(),
            max_cascade_tokens: None,
            cost_tiers: None,
        }
    }
}
/// How the cascade judges answer quality (lowercase on the wire).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    /// Rule-based scoring (the default).
    #[default]
    Heuristic,
    /// LLM-as-judge scoring.
    Judge,
}
// Serde defaults for `BanditConfig`.

/// Exploration parameter (default 1.0).
fn default_bandit_alpha() -> f32 {
    1.0
}

/// Context-embedding dimensionality (default 32).
fn default_bandit_dim() -> usize {
    32
}

/// Weight of cost in the reward signal (default 0.1).
fn default_bandit_cost_weight() -> f32 {
    0.1
}

/// Observation decay; 1.0 means no decay.
fn default_bandit_decay_factor() -> f32 {
    1.0
}

/// Timeout for fetching context embeddings, in milliseconds.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

/// Embedding-cache capacity (entries).
fn default_bandit_cache_size() -> usize {
    512
}
/// Contextual-bandit routing settings under `[llm.router.bandit]`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// Exploration parameter (default 1.0).
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,
    /// Context-embedding dimensionality (default 32).
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,
    /// Weight of cost in the reward signal (default 0.1).
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,
    /// Observation decay factor (default 1.0 = no decay).
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,
    /// Name of the provider used to embed contexts; empty by default.
    #[serde(default)]
    pub embedding_provider: String,
    /// Embedding-fetch timeout in milliseconds (default 50).
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,
    /// Embedding-cache capacity (default 512 entries).
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,
    /// Optional path for persisting bandit state.
    #[serde(default)]
    pub state_path: Option<String>,
    /// Confidence threshold for memory-based decisions (default 0.9).
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,
}
/// Serde default for `BanditConfig::memory_confidence_threshold`.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9
}
impl Default for BanditConfig {
    /// Mirrors the `#[serde(default = ...)]` attributes so programmatic
    /// construction matches deserializing an empty `[llm.router.bandit]`.
    fn default() -> Self {
        Self {
            alpha: default_bandit_alpha(),
            dim: default_bandit_dim(),
            cost_weight: default_bandit_cost_weight(),
            decay_factor: default_bandit_decay_factor(),
            embedding_provider: String::new(),
            embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
            cache_size: default_bandit_cache_size(),
            state_path: None,
            memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
        }
    }
}
/// Candle (local inference) backend settings under `[llm.candle]`.
///
/// Field-for-field identical to `CandleInlineConfig` (the per-provider
/// variant); `Clone` is derived here for consistency with that twin and so
/// containers of `CandleConfig` can themselves be cloned.
// NOTE(review): consider consolidating with `CandleInlineConfig` — the two
// structs are byte-identical in shape.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Model source; defaults to "huggingface".
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Optional specific model file name within the source.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template; defaults to "chatml".
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device; defaults to "cpu".
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face access token.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling parameters (see `GenerationParams`).
    #[serde(default)]
    pub generation: GenerationParams,
}
/// Sampling parameters for local (Candle) generation.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature (default 0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Optional nucleus-sampling cutoff.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Optional top-k sampling cutoff.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Token budget (default 2048; see `capped_max_tokens`).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed (default 42).
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty (default 1.1).
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Trailing tokens considered by the repeat penalty (default 64).
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
/// Hard upper bound applied to `GenerationParams::max_tokens`.
pub const MAX_TOKENS_CAP: usize = 32768;

impl GenerationParams {
    /// `max_tokens` clamped to [`MAX_TOKENS_CAP`], guarding against
    /// configured values that exceed the supported budget.
    #[must_use]
    pub fn capped_max_tokens(&self) -> usize {
        std::cmp::min(self.max_tokens, MAX_TOKENS_CAP)
    }
}
impl Default for GenerationParams {
    /// Mirrors the `#[serde(default = ...)]` attributes so programmatic
    /// construction matches deserializing an empty `generation` table.
    fn default() -> Self {
        Self {
            temperature: default_temperature(),
            top_p: None,
            top_k: None,
            max_tokens: default_max_tokens(),
            seed: default_seed(),
            repeat_penalty: default_repeat_penalty(),
            repeat_last_n: default_repeat_last_n(),
        }
    }
}
/// Top-level `[llm].routing` strategy (lowercase on the wire).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// No routing; the field is omitted from serialized output (the default).
    #[default]
    None,
    /// Exponential-moving-average ranking.
    Ema,
    /// Thompson sampling.
    Thompson,
    /// Cascade with escalation.
    Cascade,
    /// Per-task routing via `routes`.
    Task,
    /// Complexity triage (see `ComplexityRoutingConfig`).
    Triage,
    /// Contextual bandit.
    Bandit,
}
// Serde defaults for `ComplexityRoutingConfig`.

/// Triage-call timeout in seconds.
fn default_triage_timeout_secs() -> u64 {
    5
}

/// Token budget for the triage call.
fn default_max_triage_tokens() -> u32 {
    50
}

/// Helper for boolean fields whose serde default is `true`.
fn default_true() -> bool {
    true
}
/// Maps complexity tiers to provider names; unset tiers fall through to
/// whatever the triage router does without a mapping.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    // Provider name for trivially simple requests.
    pub simple: Option<String>,
    // Provider name for medium-complexity requests.
    pub medium: Option<String>,
    // Provider name for complex requests.
    pub complex: Option<String>,
    // Provider name for expert-level requests.
    pub expert: Option<String>,
}
/// `[llm.complexity_routing]` (triage) settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider used for the triage classification call.
    #[serde(default)]
    pub triage_provider: Option<String>,
    /// Skip triage when only one provider is configured (default true).
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,
    /// Complexity-tier → provider mapping.
    #[serde(default)]
    pub tiers: TierMapping,
    /// Token budget for the triage call (default 50).
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,
    /// Triage-call timeout in seconds (default 5).
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,
    /// Strategy name to fall back to when triage fails.
    // NOTE(review): accepted values not visible here — confirm against the
    // triage router.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
impl Default for ComplexityRoutingConfig {
    /// Mirrors the `#[serde(default = ...)]` attributes so programmatic
    /// construction matches deserializing an empty table.
    fn default() -> Self {
        Self {
            triage_provider: None,
            bypass_single_provider: true,
            tiers: TierMapping::default(),
            max_triage_tokens: default_max_triage_tokens(),
            triage_timeout_secs: default_triage_timeout_secs(),
            fallback_strategy: None,
        }
    }
}
/// Per-provider Candle settings embedded in a `[[llm.providers]]` entry.
/// Field-for-field identical to the top-level `CandleConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Model source; defaults to "huggingface".
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Optional specific model file name within the source.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template; defaults to "chatml".
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device; defaults to "cpu".
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face access token.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling parameters (see `GenerationParams`).
    #[serde(default)]
    pub generation: GenerationParams,
}
impl Default for CandleInlineConfig {
    /// Mirrors the `#[serde(default = ...)]` attributes so programmatic
    /// construction matches deserializing an empty table.
    fn default() -> Self {
        Self {
            source: default_candle_source(),
            local_path: String::new(),
            filename: None,
            chat_template: default_chat_template(),
            device: default_candle_device(),
            embedding_repo: None,
            hf_token: None,
            generation: GenerationParams::default(),
        }
    }
}
/// One `[[llm.providers]]` entry.
///
/// Only `type` is required; everything else has a serde default. Several
/// fields are meaningful only for specific provider kinds — `validate`
/// warns when a kind-specific field is set on the wrong kind.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Backend kind (TOML key `type`).
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,
    /// Unique name; defaults to the kind's canonical string. Required for
    /// `compatible` entries.
    #[serde(default)]
    pub name: Option<String>,
    /// Model id; see `effective_model` for per-kind defaults.
    #[serde(default)]
    pub model: Option<String>,
    /// API base URL.
    #[serde(default)]
    pub base_url: Option<String>,
    /// Response token cap.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Embedding model served by this provider.
    #[serde(default)]
    pub embedding_model: Option<String>,
    /// Speech-to-text model; selecting a provider for STT requires this.
    #[serde(default)]
    pub stt_model: Option<String>,
    /// Use this provider for embeddings.
    #[serde(default)]
    pub embed: bool,
    /// Mark as the pool default; at most one entry may set this
    /// (enforced by `validate_pool`).
    #[serde(default)]
    pub default: bool,
    /// Claude-only extended-thinking settings.
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    /// Server-side conversation compaction toggle.
    #[serde(default)]
    pub server_compaction: bool,
    /// Extended-context toggle.
    #[serde(default)]
    pub enable_extended_context: bool,
    /// OpenAI-only reasoning-effort setting.
    #[serde(default)]
    pub reasoning_effort: Option<String>,
    /// Gemini-only thinking level.
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Gemini-only thinking token budget.
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// Include model thoughts in responses.
    #[serde(default)]
    pub include_thoughts: Option<bool>,
    /// Ollama-only tool-use toggle.
    #[serde(default)]
    pub tool_use: bool,
    /// API key; absent means use the environment or no auth.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Inline Candle settings for `type = "candle"` entries.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,
    /// Vision model served by this provider.
    #[serde(default)]
    pub vision_model: Option<String>,
    /// Per-provider system-instruction file.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
impl Default for ProviderEntry {
    /// An Ollama entry with every optional field unset and every flag off —
    /// matching the per-field serde defaults. Manual impl because
    /// `ProviderKind` has no `Default`.
    fn default() -> Self {
        Self {
            provider_type: ProviderKind::Ollama,
            name: None,
            model: None,
            base_url: None,
            max_tokens: None,
            embedding_model: None,
            stt_model: None,
            embed: false,
            default: false,
            thinking: None,
            server_compaction: false,
            enable_extended_context: false,
            reasoning_effort: None,
            thinking_level: None,
            thinking_budget: None,
            include_thoughts: None,
            tool_use: false,
            api_key: None,
            candle: None,
            vision_model: None,
            instruction_file: None,
        }
    }
}
impl ProviderEntry {
    /// Resolved name: explicit `name`, else the kind's canonical string
    /// (e.g. "ollama"). Used as the unique key in `validate_pool`.
    #[must_use]
    pub fn effective_name(&self) -> String {
        self.name
            .clone()
            .unwrap_or_else(|| self.provider_type.as_str().to_owned())
    }

    /// Resolved model id: explicit `model`, else a per-kind default.
    /// Candle and compatible entries have no default model and yield "".
    #[must_use]
    pub fn effective_model(&self) -> String {
        if let Some(ref m) = self.model {
            return m.clone();
        }
        match self.provider_type {
            ProviderKind::Ollama => "qwen3:8b".to_owned(),
            ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
            ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
            ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
            ProviderKind::Compatible | ProviderKind::Candle => String::new(),
        }
    }

    // -- warning helpers: each flags one kind-specific field when it is set
    //    on an entry whose provider kind ignores it --

    /// Warn if `thinking` (Claude-only) is set.
    fn warn_thinking_unused(&self) {
        if self.thinking.is_some() {
            tracing::warn!(
                provider = self.effective_name(),
                "field `thinking` is only used by Claude providers"
            );
        }
    }

    /// Warn if `reasoning_effort` (OpenAI-only) is set.
    fn warn_reasoning_effort_unused(&self) {
        if self.reasoning_effort.is_some() {
            tracing::warn!(
                provider = self.effective_name(),
                "field `reasoning_effort` is only used by OpenAI providers"
            );
        }
    }

    /// Warn if `thinking_level`/`thinking_budget` (Gemini-only) are set.
    fn warn_gemini_fields_unused(&self) {
        if self.thinking_level.is_some() || self.thinking_budget.is_some() {
            tracing::warn!(
                provider = self.effective_name(),
                "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
            );
        }
    }

    /// Warn if `tool_use` (Ollama-only) is set.
    fn warn_tool_use_unused(&self) {
        if self.tool_use {
            tracing::warn!(
                provider = self.effective_name(),
                "field `tool_use` is only used by Ollama providers"
            );
        }
    }

    /// Validates a single entry.
    ///
    /// # Errors
    /// `compatible` entries must set `name`; anything else is a warning:
    /// kind-specific fields set on the wrong kind, and `stt_model` on an
    /// Ollama entry (Ollama has no Whisper-style STT API).
    pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
        use crate::error::ConfigError;
        if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
            return Err(ConfigError::Validation(
                "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
            ));
        }
        // Each kind warns about the kind-specific fields it does NOT consume.
        match self.provider_type {
            ProviderKind::Ollama => {
                self.warn_thinking_unused();
                self.warn_reasoning_effort_unused();
                self.warn_gemini_fields_unused();
            }
            ProviderKind::Claude => {
                self.warn_reasoning_effort_unused();
                self.warn_gemini_fields_unused();
                self.warn_tool_use_unused();
            }
            ProviderKind::OpenAi => {
                self.warn_thinking_unused();
                self.warn_gemini_fields_unused();
                self.warn_tool_use_unused();
            }
            ProviderKind::Gemini => {
                self.warn_thinking_unused();
                self.warn_reasoning_effort_unused();
                self.warn_tool_use_unused();
            }
            // Candle/compatible accept any combination without warnings.
            ProviderKind::Candle | ProviderKind::Compatible => {}
        }
        if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
            tracing::warn!(
                provider = self.effective_name(),
                "field `stt_model` is set on an Ollama provider; Ollama does not support the \
                 Whisper STT API — use OpenAI, compatible, or candle instead"
            );
        }
        Ok(())
    }
}
/// Validates the whole `[[llm.providers]]` pool.
///
/// # Errors
/// Fails when the pool is empty, more than one entry is `default = true`,
/// two entries share an effective name, or any single entry fails
/// [`ProviderEntry::validate`].
pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
    use crate::error::ConfigError;
    use std::collections::HashSet;
    if entries.is_empty() {
        return Err(ConfigError::Validation(
            "at least one LLM provider must be configured in [[llm.providers]]".into(),
        ));
    }
    if entries.iter().filter(|e| e.default).count() > 1 {
        return Err(ConfigError::Validation(
            "only one [[llm.providers]] entry can be marked `default = true`".into(),
        ));
    }
    let mut seen: HashSet<String> = HashSet::new();
    for entry in entries {
        let name = entry.effective_name();
        // Name uniqueness is checked before per-entry validation so a
        // duplicate invalid entry reports the duplicate first.
        if seen.contains(&name) {
            return Err(ConfigError::Validation(format!(
                "duplicate provider name \"{name}\" in [[llm.providers]]"
            )));
        }
        seen.insert(name);
        entry.validate()?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid Ollama fixture.
    fn ollama_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Ollama,
            name: Some("ollama".into()),
            model: Some("qwen3:8b".into()),
            ..Default::default()
        }
    }

    /// Minimal valid Claude fixture.
    fn claude_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Claude,
            name: Some("claude".into()),
            model: Some("claude-sonnet-4-6".into()),
            max_tokens: Some(8192),
            ..Default::default()
        }
    }

    // --- ProviderEntry::validate ---

    #[test]
    fn validate_ollama_valid() {
        assert!(ollama_entry().validate().is_ok());
    }

    #[test]
    fn validate_claude_valid() {
        assert!(claude_entry().validate().is_ok());
    }

    #[test]
    fn validate_compatible_without_name_errors() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: None,
            ..Default::default()
        };
        let err = entry.validate().unwrap_err();
        assert!(
            err.to_string().contains("compatible"),
            "error should mention compatible: {err}"
        );
    }

    #[test]
    fn validate_compatible_with_name_ok() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: Some("my-proxy".into()),
            base_url: Some("http://localhost:8080".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_openai_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            name: Some("openai".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_gemini_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            name: Some("gemini".into()),
            model: Some("gemini-2.0-flash".into()),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    // --- validate_pool ---

    #[test]
    fn validate_pool_empty_errors() {
        let err = validate_pool(&[]).unwrap_err();
        assert!(err.to_string().contains("at least one"), "{err}");
    }

    #[test]
    fn validate_pool_single_entry_ok() {
        assert!(validate_pool(&[ollama_entry()]).is_ok());
    }

    #[test]
    fn validate_pool_duplicate_names_errors() {
        let a = ollama_entry();
        let b = ollama_entry();
        let err = validate_pool(&[a, b]).unwrap_err();
        assert!(err.to_string().contains("duplicate"), "{err}");
    }

    #[test]
    fn validate_pool_multiple_defaults_errors() {
        let mut a = ollama_entry();
        let mut b = claude_entry();
        a.default = true;
        b.default = true;
        let err = validate_pool(&[a, b]).unwrap_err();
        assert!(err.to_string().contains("default"), "{err}");
    }

    #[test]
    fn validate_pool_two_different_providers_ok() {
        assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
    }

    #[test]
    fn validate_pool_propagates_entry_error() {
        let bad = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: None,
            ..Default::default()
        };
        assert!(validate_pool(&[bad]).is_err());
    }

    // --- ProviderEntry::effective_model per-kind defaults ---

    #[test]
    fn effective_model_returns_explicit_when_set() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: Some("claude-sonnet-4-6".into()),
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
    }

    #[test]
    fn effective_model_ollama_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Ollama,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_model_claude_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
    }

    #[test]
    fn effective_model_openai_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gpt-4o-mini");
    }

    #[test]
    fn effective_model_gemini_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gemini-2.0-flash");
    }

    /// Parses a TOML snippet containing an `[llm]` table into `LlmConfig`.
    fn parse_llm(toml: &str) -> LlmConfig {
        #[derive(serde::Deserialize)]
        struct Wrapper {
            llm: LlmConfig,
        }
        toml::from_str::<Wrapper>(toml).unwrap().llm
    }

    // --- LlmConfig accessors and TOML parsing ---

    #[test]
    fn check_legacy_format_new_format_ok() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn check_legacy_format_empty_providers_no_legacy_ok() {
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn effective_provider_falls_back_to_ollama_when_no_providers() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
    }

    #[test]
    fn effective_provider_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "claude"
model = "claude-sonnet-4-6"
"#,
        );
        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
    }

    #[test]
    fn effective_model_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert_eq!(cfg.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_base_url_default_when_absent() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
    }

    #[test]
    fn effective_base_url_from_providers_entry() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "ollama"
base_url = "http://myhost:11434"
"#,
        );
        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
    }

    // --- complexity routing ---

    #[test]
    fn complexity_routing_defaults() {
        let cr = ComplexityRoutingConfig::default();
        assert!(
            cr.bypass_single_provider,
            "bypass_single_provider must default to true"
        );
        assert_eq!(cr.triage_timeout_secs, 5);
        assert_eq!(cr.max_triage_tokens, 50);
        assert!(cr.triage_provider.is_none());
        assert!(cr.tiers.simple.is_none());
    }

    #[test]
    fn complexity_routing_toml_round_trip() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
[llm.complexity_routing]
triage_provider = "fast"
bypass_single_provider = false
triage_timeout_secs = 10
max_triage_tokens = 100
[llm.complexity_routing.tiers]
simple = "fast"
medium = "medium"
complex = "large"
expert = "opus"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
        assert!(!cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 10);
        assert_eq!(cr.max_triage_tokens, 100);
        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
    }

    #[test]
    fn complexity_routing_partial_tiers_toml() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
[llm.complexity_routing.tiers]
simple = "haiku"
complex = "sonnet"
"#,
        );
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
        assert!(cr.tiers.medium.is_none());
        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
        assert!(cr.tiers.expert.is_none());
        assert!(cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 5);
    }

    #[test]
    fn routing_strategy_triage_deserialized() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
    }

    // --- STT resolution and validation ---

    #[test]
    fn stt_provider_entry_by_name_match() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
stt_model = "gpt-4o-mini-transcribe"
[llm.stt]
provider = "quality"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should find stt provider");
        assert_eq!(entry.effective_name(), "quality");
        assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
    }

    #[test]
    fn stt_provider_entry_auto_detect_when_provider_empty() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"
[llm.stt]
provider = ""
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_auto_detect_no_stt_section() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_none_when_no_stt_model() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
"#,
        );
        assert!(cfg.stt_provider_entry().is_none());
    }

    #[test]
    fn stt_provider_entry_name_mismatch_falls_back_to_none() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"
[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.stt_provider_entry().is_none());
    }

    #[test]
    fn stt_config_deserializes_new_slim_format() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
stt_model = "whisper-1"
[llm.stt]
provider = "quality"
language = "en"
"#,
        );
        let stt = cfg.stt.as_ref().expect("stt section present");
        assert_eq!(stt.provider, "quality");
        assert_eq!(stt.language, "en");
    }

    #[test]
    fn stt_config_default_provider_is_empty() {
        assert_eq!(default_stt_provider(), "");
    }

    #[test]
    fn validate_stt_missing_provider_ok() {
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.validate_stt().is_ok());
    }

    #[test]
    fn validate_stt_valid_reference() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
stt_model = "whisper-1"
[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.validate_stt().is_ok());
    }

    #[test]
    fn validate_stt_nonexistent_provider_errors() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
[llm.stt]
provider = "nonexistent"
"#,
        );
        assert!(cfg.validate_stt().is_err());
    }

    #[test]
    fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
        let cfg = parse_llm(
            r#"
[llm]
[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.validate_stt().is_ok());
        assert!(
            cfg.stt_provider_entry().is_none(),
            "stt_provider_entry must be None when provider has no stt_model"
        );
    }
}