use zeph_llm::any::AnyProvider;
use zeph_llm::claude::ClaudeProvider;
use zeph_llm::compatible::CompatibleProvider;
use zeph_llm::gemini::GeminiProvider;
use zeph_llm::http::llm_client;
use zeph_llm::ollama::OllamaProvider;
use zeph_llm::openai::OpenAiProvider;
use crate::agent::state::ProviderConfigSnapshot;
use crate::config::{Config, ProviderEntry, ProviderKind};
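/// Errors that can occur while bootstrapping the agent: configuration loading,
/// provider construction, memory setup, vault initialization, and I/O.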
#[derive(Debug, thiserror::Error)]
pub enum BootstrapError {
#[error("config error: {0}")]
Config(#[from] crate::config::ConfigError),
#[error("provider error: {0}")]
Provider(String),
#[error("memory error: {0}")]
Memory(String),
#[error("vault init error: {0}")]
VaultInit(crate::vault::AgeVaultError),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
}
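/// Rebuilds an [`AnyProvider`] from a runtime [`ProviderConfigSnapshot`], used when
/// switching the active provider at runtime.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] if the entry cannot be constructed, for
/// example when a required API key is missing from the snapshot.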
pub fn build_provider_for_switch(
entry: &ProviderEntry,
snapshot: &ProviderConfigSnapshot,
) -> Result<AnyProvider, BootstrapError> {
use zeph_common::secret::Secret;
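    // Assemble a minimal Config carrying only the secrets, timeout, and embedding
    // model that the provider constructors read; everything else keeps its default.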
let mut config = Config::default();
config.secrets.claude_api_key = snapshot.claude_api_key.as_deref().map(Secret::new);
config.secrets.openai_api_key = snapshot.openai_api_key.as_deref().map(Secret::new);
config.secrets.gemini_api_key = snapshot.gemini_api_key.as_deref().map(Secret::new);
config.secrets.compatible_api_keys = snapshot
.compatible_api_keys
.iter()
.map(|(k, v)| (k.clone(), Secret::new(v.as_str())))
.collect();
config.timeouts.llm_request_timeout_secs = snapshot.llm_request_timeout_secs;
config
.llm
.embedding_model
.clone_from(&snapshot.embedding_model);
build_provider_from_entry(entry, &config)
}
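/// Constructs the concrete [`AnyProvider`] described by a `[[llm.providers]]` entry,
/// filling in per-provider defaults (base URL, model, token limits) wherever the
/// entry leaves them unset.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] when a required API key or field is missing,
/// or when provider-specific options (such as the thinking configuration) are invalid.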
#[allow(clippy::too_many_lines)]
pub fn build_provider_from_entry(
entry: &ProviderEntry,
config: &Config,
) -> Result<AnyProvider, BootstrapError> {
match entry.provider_type {
ProviderKind::Ollama => {
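            // Local Ollama server; fall back to http://localhost:11434 and qwen3:8b
            // when the entry omits a base URL or model.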
let base_url = entry
.base_url
.as_deref()
.unwrap_or("http://localhost:11434");
let model = entry.model.as_deref().unwrap_or("qwen3:8b").to_owned();
let embed = entry
.embedding_model
.clone()
.unwrap_or_else(|| config.llm.embedding_model.clone());
let mut provider = OllamaProvider::new(base_url, model, embed);
if let Some(ref vm) = entry.vision_model {
provider = provider.with_vision_model(vm.clone());
}
Ok(AnyProvider::Ollama(provider))
}
ProviderKind::Claude => {
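            // Anthropic Claude; the API key must already be present in the
            // vault-backed secrets.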
let api_key = config
.secrets
.claude_api_key
.as_ref()
.ok_or_else(|| {
BootstrapError::Provider("ZEPH_CLAUDE_API_KEY not found in vault".into())
})?
.expose()
.to_owned();
let model = entry
.model
.clone()
.unwrap_or_else(|| "claude-haiku-4-5-20251001".to_owned());
let max_tokens = entry.max_tokens.unwrap_or(4096);
let provider = ClaudeProvider::new(api_key, model, max_tokens)
.with_client(llm_client(config.timeouts.llm_request_timeout_secs))
.with_extended_context(entry.enable_extended_context)
.with_thinking_opt(entry.thinking.clone())
.map_err(|e| BootstrapError::Provider(format!("invalid thinking config: {e}")))?
.with_server_compaction(entry.server_compaction);
Ok(AnyProvider::Claude(provider))
}
ProviderKind::OpenAi => {
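            // OpenAI, or any endpoint speaking the same API when `base_url` is
            // overridden.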
let api_key = config
.secrets
.openai_api_key
.as_ref()
.ok_or_else(|| {
BootstrapError::Provider("ZEPH_OPENAI_API_KEY not found in vault".into())
})?
.expose()
.to_owned();
let base_url = entry
.base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com/v1".to_owned());
let model = entry
.model
.clone()
.unwrap_or_else(|| "gpt-4o-mini".to_owned());
let max_tokens = entry.max_tokens.unwrap_or(4096);
Ok(AnyProvider::OpenAi(
OpenAiProvider::new(
api_key,
base_url,
model,
max_tokens,
entry.embedding_model.clone(),
entry.reasoning_effort.clone(),
)
.with_client(llm_client(config.timeouts.llm_request_timeout_secs)),
))
}
ProviderKind::Gemini => {
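            // Google Gemini; thinking level, thinking budget, and thought inclusion
            // are optional tuning knobs.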
let api_key = config
.secrets
.gemini_api_key
.as_ref()
.ok_or_else(|| {
BootstrapError::Provider("ZEPH_GEMINI_API_KEY not found in vault".into())
})?
.expose()
.to_owned();
let model = entry
.model
.clone()
.unwrap_or_else(|| "gemini-2.0-flash".to_owned());
let max_tokens = entry.max_tokens.unwrap_or(8192);
let base_url = entry
.base_url
.clone()
.unwrap_or_else(|| "https://generativelanguage.googleapis.com".to_owned());
let mut provider = GeminiProvider::new(api_key, model, max_tokens)
.with_base_url(base_url)
.with_client(llm_client(config.timeouts.llm_request_timeout_secs));
if let Some(ref em) = entry.embedding_model {
provider = provider.with_embedding_model(em.clone());
}
if let Some(level) = entry.thinking_level {
provider = provider.with_thinking_level(level);
}
if let Some(budget) = entry.thinking_budget {
provider = provider
.with_thinking_budget(budget)
.map_err(|e| BootstrapError::Provider(e.to_string()))?;
}
if let Some(include) = entry.include_thoughts {
provider = provider.with_include_thoughts(include);
}
Ok(AnyProvider::Gemini(provider))
}
ProviderKind::Compatible => {
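            // OpenAI-compatible endpoint identified by `name`; the API key may be set
            // inline on the entry or resolved from the per-provider key map.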
let name = entry.name.as_deref().ok_or_else(|| {
BootstrapError::Provider(
"compatible provider requires 'name' field in [[llm.providers]]".into(),
)
})?;
let base_url = entry.base_url.clone().ok_or_else(|| {
BootstrapError::Provider(format!(
"compatible provider '{name}' requires 'base_url'"
))
})?;
let model = entry.model.clone().unwrap_or_default();
let api_key = entry.api_key.clone().unwrap_or_else(|| {
config
.secrets
.compatible_api_keys
.get(name)
.map(|s| s.expose().to_owned())
.unwrap_or_default()
});
let max_tokens = entry.max_tokens.unwrap_or(4096);
Ok(AnyProvider::Compatible(CompatibleProvider::new(
name.to_owned(),
api_key,
base_url,
model,
max_tokens,
entry.embedding_model.clone(),
)))
}
#[cfg(feature = "candle")]
ProviderKind::Candle => {
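            // In-process inference via Candle: resolve the model source (local path or
            // Hugging Face repo), chat template, generation settings, and device before
            // constructing the provider.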
let candle = entry.candle.as_ref().ok_or_else(|| {
BootstrapError::Provider(
"candle provider requires 'candle' section in [[llm.providers]]".into(),
)
})?;
let source = match candle.source.as_str() {
"local" => zeph_llm::candle_provider::loader::ModelSource::Local {
path: std::path::PathBuf::from(&candle.local_path),
},
_ => zeph_llm::candle_provider::loader::ModelSource::HuggingFace {
repo_id: entry
.model
.clone()
.unwrap_or_else(|| config.llm.effective_model().to_owned()),
filename: candle.filename.clone(),
},
};
let template =
zeph_llm::candle_provider::template::ChatTemplate::parse_str(&candle.chat_template);
let gen_config = zeph_llm::candle_provider::generate::GenerationConfig {
temperature: candle.generation.temperature,
top_p: candle.generation.top_p,
top_k: candle.generation.top_k,
max_tokens: candle.generation.capped_max_tokens(),
seed: candle.generation.seed,
repeat_penalty: candle.generation.repeat_penalty,
repeat_last_n: candle.generation.repeat_last_n,
};
let device = select_device(&candle.device)?;
let inference_timeout =
std::time::Duration::from_secs(candle.inference_timeout_secs.max(1));
zeph_llm::candle_provider::CandleProvider::new_with_timeout(
&source,
template,
gen_config,
candle.embedding_repo.as_deref(),
candle.hf_token.as_deref(),
device,
inference_timeout,
)
.map(AnyProvider::Candle)
.map_err(|e| BootstrapError::Provider(e.to_string()))
}
#[cfg(not(feature = "candle"))]
ProviderKind::Candle => Err(BootstrapError::Provider(
"candle feature is not enabled".into(),
)),
}
}
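/// Resolves the Candle compute device from the configured preference: `"metal"`,
/// `"cuda"`, `"auto"` (Metal, then CUDA, then CPU), or CPU for anything else.
///
/// # Errors
///
/// Returns [`BootstrapError::Provider`] if the requested backend fails to initialize
/// or was not compiled in.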
#[cfg(feature = "candle")]
pub fn select_device(
preference: &str,
) -> Result<zeph_llm::candle_provider::Device, BootstrapError> {
match preference {
"metal" => {
#[cfg(feature = "metal")]
return zeph_llm::candle_provider::Device::new_metal(0)
.map_err(|e| BootstrapError::Provider(e.to_string()));
#[cfg(not(feature = "metal"))]
return Err(BootstrapError::Provider(
"candle compiled without metal feature".into(),
));
}
"cuda" => {
#[cfg(feature = "cuda")]
return zeph_llm::candle_provider::Device::new_cuda(0)
.map_err(|e| BootstrapError::Provider(e.to_string()));
#[cfg(not(feature = "cuda"))]
return Err(BootstrapError::Provider(
"candle compiled without cuda feature".into(),
));
}
"auto" => {
#[cfg(feature = "metal")]
if let Ok(device) = zeph_llm::candle_provider::Device::new_metal(0) {
return Ok(device);
}
#[cfg(feature = "cuda")]
if let Ok(device) = zeph_llm::candle_provider::Device::new_cuda(0) {
return Ok(device);
}
Ok(zeph_llm::candle_provider::Device::Cpu)
}
_ => Ok(zeph_llm::candle_provider::Device::Cpu),
}
}
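/// Resolves the embedding model to use: the provider entry flagged with `embed`, then
/// the first configured provider, then the global `llm.embedding_model` default.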
#[must_use]
pub fn effective_embedding_model(config: &Config) -> String {
if let Some(m) = config
.llm
.providers
.iter()
.find(|e| e.embed)
.and_then(|e| e.embedding_model.as_ref())
{
return m.clone();
}
if let Some(m) = config
.llm
.providers
.first()
.and_then(|e| e.embedding_model.as_ref())
{
return m.clone();
}
config.llm.embedding_model.clone()
}
#[cfg(test)]
mod tests {
#[cfg(feature = "candle")]
use super::select_device;
#[cfg(feature = "candle")]
#[test]
fn select_device_cpu_default() {
let device = select_device("cpu").unwrap();
assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
}
#[cfg(feature = "candle")]
#[test]
fn select_device_unknown_defaults_to_cpu() {
let device = select_device("unknown").unwrap();
assert!(matches!(device, zeph_llm::candle_provider::Device::Cpu));
}
#[cfg(all(feature = "candle", not(feature = "metal")))]
#[test]
fn select_device_metal_without_feature_errors() {
let result = select_device("metal");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("metal feature"));
}
#[cfg(all(feature = "candle", not(feature = "cuda")))]
#[test]
fn select_device_cuda_without_feature_errors() {
let result = select_device("cuda");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("cuda feature"));
}
#[cfg(feature = "candle")]
#[test]
fn select_device_auto_fallback() {
let device = select_device("auto").unwrap();
assert!(matches!(
device,
zeph_llm::candle_provider::Device::Cpu
| zeph_llm::candle_provider::Device::Cuda(_)
| zeph_llm::candle_provider::Device::Metal(_)
));
}
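    // Hedged sketch: exercises the global fallback in `effective_embedding_model`.
    // It assumes `Config::default()` ships with an empty provider list; the guard
    // keeps the test from failing if that assumption does not hold.
    #[test]
    fn effective_embedding_model_defaults_to_global_setting() {
        let config = crate::config::Config::default();
        if config.llm.providers.is_empty() {
            assert_eq!(
                super::effective_embedding_model(&config),
                config.llm.embedding_model
            );
        }
    }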
}