use std::fmt::Write as _;
use super::Agent;
use crate::channel::Channel;
use zeph_llm::provider::LlmProvider as _;
impl<C: Channel> Agent<C> {
/// Reloads the agent's instruction blocks after a provider switch.
///
/// Returns early without touching anything when no reload state is present.
/// Otherwise the tracked provider kinds are replaced by the kind of `entry`,
/// the entry's `instruction_file` (when set and not already listed) is appended
/// to the explicit file list, and `self.instructions.blocks` is replaced by the
/// result of `crate::instructions::load_instructions`.
fn update_provider_instructions(&mut self, entry: &zeph_config::ProviderEntry) {
    let Some(ref mut reload_state) = self.instructions.reload_state else {
        return;
    };
    reload_state.provider_kinds = vec![entry.provider_type];
    if let Some(ref path) = entry.instruction_file
        && !reload_state.explicit_files.contains(path)
    {
        reload_state.explicit_files.push(path.clone());
    }
    // The loader only needs shared access to the reload state, and that borrow
    // ends before `blocks` is assigned below, so no defensive copies are needed.
    let new_blocks = crate::instructions::load_instructions(
        &reload_state.base_dir,
        &reload_state.provider_kinds,
        &reload_state.explicit_files,
        reload_state.auto_detect,
    );
    tracing::info!(
        count = new_blocks.len(),
        provider = ?entry.provider_type,
        "reloaded instruction files after provider switch"
    );
    self.instructions.blocks = new_blocks;
}
/// Publishes the newly active provider's identity and sampling parameters to the metrics.
///
/// `provider_name` is set to `configured_name` and `model_name` to the agent's
/// current runtime model name. Temperature and top-p come from the entry's
/// `candle` generation settings; both are `None` when the entry has no `candle`
/// section, and top-p is also `None` when that setting itself is unset.
fn apply_provider_switch_metrics(
    &mut self,
    entry: &zeph_config::ProviderEntry,
    configured_name: &str,
) {
    let generation = entry.candle.as_ref().map(|candle| &candle.generation);
    // The casts narrow the configured values to `f32`, which is what the metrics fields hold.
    #[allow(clippy::cast_possible_truncation)]
    let temperature = generation.map(|params| params.temperature as f32);
    #[allow(clippy::cast_possible_truncation)]
    let top_p = generation
        .and_then(|params| params.top_p)
        .map(|value| value as f32);
    let provider_label = configured_name.to_owned();
    let model_label = self.runtime.model_name.clone();
    self.update_metrics(|m| {
        m.model_name = model_label;
        m.provider_name.clone_from(&provider_label);
        m.provider_top_p = top_p;
        m.provider_temperature = temperature;
    });
}
/// Dispatches the argument of the `/provider` command and returns the reply text.
///
/// An empty argument lists the configured providers, the literal `status`
/// reports on the current provider, and any other value is treated as the name
/// of the provider to switch to.
pub(super) fn handle_provider_command_as_string(&mut self, arg: &str) -> String {
    if arg.is_empty() {
        return self.provider_list_as_string();
    }
    if arg == "status" {
        return self.provider_status_as_string();
    }
    self.provider_switch_as_string(arg)
}
fn provider_list_as_string(&self) -> String {
let pool = &self.providers.provider_pool;
if pool.is_empty() {
return "No providers configured in [[llm.providers]].".to_owned();
}
let current = if self.runtime.active_provider_name.is_empty() {
self.provider.name().to_owned()
} else {
self.runtime.active_provider_name.clone()
};
let mut lines = vec!["Configured providers:".to_string()];
for (i, entry) in pool.iter().enumerate() {
let name = entry.effective_name();
let model = entry.model.as_deref().unwrap_or("(default)");
let marker = if name.eq_ignore_ascii_case(¤t) {
" (active)"
} else {
""
};
lines.push(format!(
" {}. {} [{}] model={}{}",
i + 1,
name,
entry.provider_type,
model,
marker
));
}
lines.join("\n")
}
fn provider_status_as_string(&self) -> String {
let mut out = String::from("Current provider:\n\n");
let display_name = if self.runtime.active_provider_name.is_empty() {
self.provider.name().to_owned()
} else {
self.runtime.active_provider_name.clone()
};
let _ = writeln!(out, "Name: {display_name}");
let _ = writeln!(out, "Model: {}", self.runtime.model_name);
if let Some(ref tx) = self.metrics.metrics_tx {
let m = tx.borrow();
let _ = writeln!(out, "API calls: {}", m.api_calls);
let _ = writeln!(
out,
"Tokens: {} prompt / {} completion",
m.prompt_tokens, m.completion_tokens
);
if m.cost_spent_cents > 0.0 {
let _ = writeln!(out, "Cost: ${:.4}", m.cost_spent_cents / 100.0);
}
}
out.trim_end().to_owned()
}
fn provider_switch_as_string(&mut self, name: &str) -> String {
let entry_clone = self
.providers
.provider_pool
.iter()
.find(|e| e.effective_name().eq_ignore_ascii_case(name))
.cloned();
let Some(entry) = entry_clone else {
let names: Vec<_> = self
.providers
.provider_pool
.iter()
.map(zeph_config::ProviderEntry::effective_name)
.collect();
return format!(
"Unknown provider '{}'. Available: {}",
name,
names.join(", ")
);
};
let current_name = if self.runtime.active_provider_name.is_empty() {
self.provider.name().to_owned()
} else {
self.runtime.active_provider_name.clone()
};
if current_name.eq_ignore_ascii_case(name) {
return format!("Provider '{current_name}' is already active.");
}
let Some(ref snapshot) = self.providers.provider_config_snapshot else {
return "Provider switching unavailable (config snapshot missing).".to_owned();
};
match crate::provider_factory::build_provider_for_switch(&entry, snapshot) {
Ok(new_provider) => {
let model_name = entry.effective_model();
let configured_name = entry.effective_name();
self.provider = new_provider;
self.runtime.model_name.clone_from(&model_name);
self.runtime
.active_provider_name
.clone_from(&configured_name);
self.providers.cached_prompt_tokens = 0;
self.providers.server_compaction_active = entry.server_compaction;
self.metrics.extended_context = entry.enable_extended_context;
tracing::info!(
provider = configured_name,
model = model_name,
"provider switched via /provider command"
);
if let Some(ref override_slot) = self.providers.provider_override {
*override_slot.write() = None;
}
self.update_provider_instructions(&entry);
self.apply_provider_switch_metrics(&entry, &configured_name);
self.build_switch_message(&configured_name)
}
Err(e) => format!("Failed to switch to '{name}': {e}"),
}
}
/// Builds the user-facing confirmation shown after a successful provider switch.
///
/// When the embedding provider's name differs (ASCII case-insensitively) from
/// `configured_name`, the message gets a trailing notice, which is also
/// logged, that embedding operations keep using the embedding provider.
fn build_switch_message(&self, configured_name: &str) -> String {
    let embed_name = self.embedding_provider.name();
    let model = &self.runtime.model_name;
    if !embed_name.eq_ignore_ascii_case(configured_name) {
        tracing::info!(
            embedding_provider = embed_name,
            "embedding operations continue using provider '{embed_name}'"
        );
        return format!(
            "Switched to provider: {} (model: {}). Embedding operations continue using \
             provider '{}'.",
            configured_name, model, embed_name
        );
    }
    format!(
        "Switched to provider: {} (model: {})",
        configured_name, model
    )
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use crate::agent::Agent;
use crate::agent::state::ProviderConfigSnapshot;
use crate::agent::tests::agent_tests::{
MockChannel, MockToolExecutor, QuickTestAgent, create_test_registry, mock_provider,
};
use zeph_config::{ProviderEntry, ProviderKind};
use zeph_llm::provider::LlmProvider as _;
/// Test helper: a `ProviderEntry` with the given name, kind and optional model;
/// every other field keeps its `Default` value.
fn make_entry(name: &str, kind: ProviderKind, model: Option<&str>) -> ProviderEntry {
    let name = Some(String::from(name));
    let model = model.map(String::from);
    ProviderEntry {
        name,
        model,
        provider_type: kind,
        ..Default::default()
    }
}
/// Test helper: a config snapshot without any API keys, a 30 second request
/// timeout and `nomic-embed-text` as the embedding model.
fn ollama_snapshot() -> ProviderConfigSnapshot {
    let no_compatible_keys = HashMap::default();
    ProviderConfigSnapshot {
        embedding_model: String::from("nomic-embed-text"),
        llm_request_timeout_secs: 30,
        compatible_api_keys: no_compatible_keys,
        claude_api_key: None,
        openai_api_key: None,
        gemini_api_key: None,
    }
}
/// A minimal test agent (expected to have no pool entries) answers the list
/// command with the "no providers configured" notice.
#[test]
fn provider_list_empty_pool() {
    let mut harness = QuickTestAgent::minimal("ok");
    let listing = harness.agent.handle_provider_command_as_string("");
    assert!(listing.contains("No providers configured"));
}
/// Listing a two-entry pool prints the header and both configured names.
///
/// NOTE(review): despite the test name, no assertion here covers the
/// ` (active)` marker itself.
#[test]
fn provider_list_shows_all_with_active_marker() {
    let pool = vec![
        make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b")),
        make_entry(
            "claude",
            ProviderKind::Claude,
            Some("claude-haiku-4-5-20251001"),
        ),
    ];
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );
    agent.providers.provider_pool = pool;
    let listing = agent.handle_provider_command_as_string("");
    assert!(listing.contains("Configured providers:"));
    assert!(listing.contains("ollama"), "should list ollama");
    assert!(listing.contains("claude"), "should list claude");
}
/// When the agent runs on a provider built from the only pool entry, that
/// entry is tagged `(active)` in the listing.
#[test]
fn provider_list_marks_active_provider() {
    let entry = make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"));
    let snapshot = ollama_snapshot();
    let active_provider =
        crate::provider_factory::build_provider_for_switch(&entry, &snapshot).unwrap();
    let mut agent = Agent::new(
        active_provider,
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );
    agent.providers.provider_config_snapshot = Some(snapshot);
    agent.providers.provider_pool = vec![entry];
    let listing = agent.handle_provider_command_as_string("");
    assert!(listing.contains("(active)"), "active entry must be marked");
}
/// Switching to a name that is not in the pool reports the unknown name and
/// lists the names that are available.
#[test]
fn provider_switch_unknown_name_returns_error() {
    let mut harness = QuickTestAgent::minimal("ok");
    harness.agent.providers.provider_pool =
        vec![make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"))];
    let reply = harness
        .agent
        .handle_provider_command_as_string("nonexistent");
    assert!(reply.contains("Unknown provider 'nonexistent'"));
    assert!(reply.contains("ollama"));
}
/// Asking to switch to the provider the agent is already running on yields the
/// "already active" reply.
#[test]
fn provider_switch_already_active_warns() {
    let entry = make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"));
    let snapshot = ollama_snapshot();
    let active_provider =
        crate::provider_factory::build_provider_for_switch(&entry, &snapshot).unwrap();
    let mut agent = Agent::new(
        active_provider,
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );
    agent.providers.provider_config_snapshot = Some(snapshot);
    agent.providers.provider_pool = vec![entry];
    let reply = agent.handle_provider_command_as_string("ollama");
    assert!(reply.contains("already active"));
}
/// Without a config snapshot on the agent, a switch to a known pool entry is
/// refused with the "config snapshot missing" reply.
#[test]
fn provider_switch_missing_snapshot_returns_error() {
    let mut harness = QuickTestAgent::minimal("ok");
    harness.agent.providers.provider_pool =
        vec![make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"))];
    let reply = harness.agent.handle_provider_command_as_string("ollama");
    assert!(reply.contains("config snapshot missing"));
}
/// A successful switch reports the new provider and model, zeroes the cached
/// prompt token count and updates the runtime model name.
#[test]
fn provider_switch_success_resets_state() {
    let entry_a = make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"));
    let entry_b = make_entry("ollama2", ProviderKind::Ollama, Some("llama3.2"));
    let snapshot = ollama_snapshot();
    let initial_provider =
        crate::provider_factory::build_provider_for_switch(&entry_a, &snapshot).unwrap();
    let mut agent = Agent::new(
        initial_provider,
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    );
    // Seed a non-zero value so the reset performed by the switch is observable.
    agent.providers.cached_prompt_tokens = 999;
    agent.providers.provider_config_snapshot = Some(snapshot);
    agent.providers.provider_pool = vec![entry_a, entry_b];

    let out = agent.handle_provider_command_as_string("ollama2");

    assert!(out.contains("Switched to provider:"), "unexpected: {out}");
    assert!(out.contains("llama3.2"));
    assert_eq!(agent.runtime.model_name, "llama3.2");
    assert_eq!(
        agent.providers.cached_prompt_tokens, 0,
        "must be reset on switch"
    );
}
/// The status command prints its header and the runtime model name on a
/// minimal test agent.
#[test]
fn provider_status_no_metrics() {
    let mut harness = QuickTestAgent::minimal("ok");
    harness.agent.runtime.model_name = String::from("test-model");
    let report = harness.agent.handle_provider_command_as_string("status");
    assert!(report.contains("test-model"));
    assert!(report.contains("Current provider:"));
}
/// Values written into a `ProviderConfigSnapshot` literal read back unchanged.
///
/// Nothing in this test awaits, so it is a plain synchronous `#[test]` and
/// needs no async runtime.
#[test]
fn provider_config_snapshot_fields() {
    let snap = ProviderConfigSnapshot {
        claude_api_key: Some("key-claude".to_owned()),
        openai_api_key: Some("key-openai".to_owned()),
        gemini_api_key: None,
        compatible_api_keys: HashMap::default(),
        llm_request_timeout_secs: 60,
        embedding_model: "nomic-embed-text".to_owned(),
    };
    assert_eq!(snap.claude_api_key.as_deref(), Some("key-claude"));
    assert_eq!(snap.openai_api_key.as_deref(), Some("key-openai"));
    assert!(snap.gemini_api_key.is_none());
    assert!(snap.compatible_api_keys.is_empty());
    assert_eq!(snap.llm_request_timeout_secs, 60);
    assert_eq!(snap.embedding_model, "nomic-embed-text");
}
/// No embedding notice is appended when the embedding provider's name equals
/// the name the switch message is built for.
///
/// NOTE(review): the embedding provider is built from an Ollama entry named
/// "mock2" while the message is requested for "ollama", so this relies on the
/// built provider reporting "ollama" as its `name()`. Confirm against the
/// provider factory.
#[test]
fn build_switch_message_no_notice_when_same_provider() {
    let entry_a = make_entry("mock", ProviderKind::Ollama, Some("qwen3:8b"));
    let entry_b = make_entry("mock2", ProviderKind::Ollama, Some("llama3.2"));
    let snapshot = ollama_snapshot();
    let embed_provider =
        crate::provider_factory::build_provider_for_switch(&entry_b, &snapshot).unwrap();
    let mut agent = Agent::new(
        mock_provider(vec![]),
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    )
    .with_embedding_provider(embed_provider);
    agent.runtime.active_provider_name = String::from("mock2");
    agent.providers.provider_config_snapshot = Some(snapshot);
    agent.providers.provider_pool = vec![entry_a, entry_b];

    let msg = agent.build_switch_message("ollama");

    assert!(
        !msg.contains("Embedding operations"),
        "no notice when embedding provider name == new chat provider name: {msg}"
    );
}
/// Switching the chat provider while a mock provider serves embeddings adds the
/// embedding notice to the reply, and the notice names the embedding provider.
#[test]
fn build_switch_message_includes_notice_when_embedding_provider_differs() {
    let entry_a = make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"));
    let entry_b = make_entry("ollama2", ProviderKind::Ollama, Some("llama3.2"));
    let snapshot = ollama_snapshot();
    let chat_provider =
        crate::provider_factory::build_provider_for_switch(&entry_a, &snapshot).unwrap();
    let mut agent = Agent::new(
        chat_provider,
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    )
    .with_embedding_provider(mock_provider(vec![]));
    agent.providers.provider_config_snapshot = Some(snapshot);
    agent.providers.provider_pool = vec![entry_a, entry_b];

    let out = agent.handle_provider_command_as_string("ollama2");

    assert!(
        out.contains("Embedding operations continue using"),
        "embedding notice expected when providers differ: {out}"
    );
    assert!(
        out.contains("mock"),
        "notice must name the embedding provider"
    );
}
/// A chat provider switch updates the runtime model name but leaves the
/// embedding provider's name as it was before the switch.
#[test]
fn provider_switch_does_not_change_embedding_provider() {
    let entry_a = make_entry("ollama", ProviderKind::Ollama, Some("qwen3:8b"));
    let entry_b = make_entry("ollama2", ProviderKind::Ollama, Some("llama3.2"));
    let entry_embed = make_entry("embed", ProviderKind::Ollama, Some("nomic-embed-text"));
    let snapshot = ollama_snapshot();
    let chat_provider =
        crate::provider_factory::build_provider_for_switch(&entry_a, &snapshot).unwrap();
    let embed_provider =
        crate::provider_factory::build_provider_for_switch(&entry_embed, &snapshot).unwrap();
    let mut agent = Agent::new(
        chat_provider,
        MockChannel::new(vec![]),
        create_test_registry(),
        None,
        5,
        MockToolExecutor::no_tools(),
    )
    .with_embedding_provider(embed_provider);
    agent.providers.provider_config_snapshot = Some(snapshot);
    agent.providers.provider_pool = vec![entry_a, entry_b];

    let embed_name_before = agent.embedding_provider.name().to_owned();
    // Only the agent state after the switch matters here, not the reply text.
    let _ = agent.handle_provider_command_as_string("ollama2");

    assert_eq!(agent.runtime.model_name, "llama3.2");
    assert_eq!(
        agent.embedding_provider.name(),
        embed_name_before,
        "embedding_provider must not change after /provider switch"
    );
}
}