use crate::admin::state::SharedState;
use anyllm_providers::{
all_providers, get_provider, list_models,
model::ModelStatus,
provider::{AuthKind, ProviderProtocol, ProviderStatus},
};
use axum::{extract::Path, extract::State, http::StatusCode, response::IntoResponse, Json};
use std::sync::LazyLock;
/// Process-wide HTTP client used when refreshing a provider's model list.
///
/// Built lazily on first use with a 10 s connect timeout, a 20 s overall
/// request timeout, and redirect following disabled.
///
/// # Panics
///
/// Panics on first access if the underlying `reqwest` client cannot be built.
static REFRESH_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    let connect_timeout = std::time::Duration::from_secs(10);
    let request_timeout = std::time::Duration::from_secs(20);

    let builder = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .connect_timeout(connect_timeout)
        .timeout(request_timeout);

    builder
        .build()
        .expect("failed to build provider refresh HTTP client")
});
fn protocol_str(p: ProviderProtocol) -> &'static str {
match p {
ProviderProtocol::OpenAICompat => "openai_compat",
ProviderProtocol::AzureOpenAI => "azure_openai",
ProviderProtocol::VertexAI => "vertex_ai",
ProviderProtocol::GeminiOpenAI => "gemini_openai",
ProviderProtocol::GeminiNative => "gemini_native",
ProviderProtocol::AnthropicNative => "anthropic_native",
ProviderProtocol::BedrockNative => "bedrock_native",
ProviderProtocol::Custom => "custom",
}
}
fn auth_str(a: AuthKind) -> &'static str {
match a {
AuthKind::Bearer => "bearer",
AuthKind::GoogleApiKey => "google_api_key",
AuthKind::AzureApiKey => "azure_api_key",
AuthKind::AwsSigV4 => "aws_sigv4",
AuthKind::None => "none",
}
}
fn provider_status_str(s: ProviderStatus) -> &'static str {
match s {
ProviderStatus::Implemented => "implemented",
ProviderStatus::Wired => "wired",
ProviderStatus::Stub => "stub",
}
}
fn model_status_str(s: ModelStatus) -> &'static str {
match s {
ModelStatus::Available => "available",
ModelStatus::Deprecated => "deprecated",
ModelStatus::Stub => "stub",
}
}
pub(super) async fn list_providers(State(shared): State<SharedState>) -> impl IntoResponse {
let cache_stats = if let Ok(db_guard) = shared.db.lock() {
crate::admin::db::get_all_provider_cache_stats(&db_guard).unwrap_or_default()
} else {
Default::default()
};
let providers: Vec<serde_json::Value> = all_providers()
.map(|p| {
let model_count = list_models(p.id).len();
let (cached_count, last_refreshed) = cache_stats
.get(p.id)
.map(|(c, r)| (*c, *r))
.unwrap_or((0, None));
serde_json::json!({
"id": p.id,
"display_name": p.display_name,
"protocol": protocol_str(p.protocol),
"auth": auth_str(p.auth),
"status": provider_status_str(p.status),
"default_base_url": p.default_base_url,
"env_vars": p.env_vars,
"litellm_prefix": p.litellm_prefix,
"capabilities": {
"chat_completions": p.capabilities.chat_completions,
"streaming": p.capabilities.streaming,
"tool_use": p.capabilities.tool_use,
"embeddings": p.capabilities.embeddings,
"vision": p.capabilities.vision,
"batch": p.capabilities.batch,
},
"model_count": model_count,
"cached_model_count": cached_count,
"last_refreshed": last_refreshed,
})
})
.collect();
Json(serde_json::json!({ "providers": providers })).into_response()
}
pub(super) async fn list_provider_models(
State(shared): State<SharedState>,
Path(provider_id): Path<String>,
) -> impl IntoResponse {
if !super::is_safe_model_name(&provider_id) {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "invalid provider id" })),
)
.into_response();
}
if get_provider(&provider_id).is_none() {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": "provider not found" })),
)
.into_response();
}
let models: Vec<serde_json::Value> = list_models(&provider_id)
.iter()
.map(|m| {
let pricing = crate::cost::price_per_million_for_model(m.id).map(|(inp, out)| {
serde_json::json!({
"input_per_million_tokens": inp,
"output_per_million_tokens": out,
})
});
serde_json::json!({
"id": m.id,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
"status": model_status_str(m.status),
"capabilities": {
"streaming": m.capabilities.streaming,
"tool_use": m.capabilities.tool_use,
"vision": m.capabilities.vision,
"extended_thinking": m.capabilities.extended_thinking,
},
"pricing": pricing,
})
})
.collect();
let cached_models: Vec<String> = if let Ok(db_guard) = shared.db.lock() {
crate::admin::db::list_cached_provider_models(&db_guard, &provider_id).unwrap_or_default()
} else {
vec![]
};
let has_models = !models.is_empty();
Json(serde_json::json!({
"provider_id": provider_id,
"has_models": has_models,
"models": models,
"cached_models": cached_models,
}))
.into_response()
}
pub(super) async fn refresh_provider_models(
State(shared): State<SharedState>,
Path(provider_id): Path<String>,
) -> impl IntoResponse {
if !super::is_safe_model_name(&provider_id) {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "invalid provider id" })),
)
.into_response();
}
let provider = match get_provider(&provider_id) {
Some(p) => p,
None => {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": "provider not found" })),
)
.into_response()
}
};
if !provider.capabilities.chat_completions {
return (
StatusCode::UNPROCESSABLE_ENTITY,
Json(serde_json::json!({
"error": "provider does not support model discovery"
})),
)
.into_response();
}
let api_key = provider.env_vars.iter().find_map(|v| std::env::var(v).ok());
let url = format!(
"{}/v1/models",
provider.default_base_url.trim_end_matches('/')
);
let mut req = REFRESH_CLIENT.get(&url);
if let Some(ref key) = api_key {
req = req.header("Authorization", format!("Bearer {key}"));
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
let msg = if e.is_connect() {
format!("connection refused: {url}")
} else if e.is_timeout() {
format!("request timed out: {url}")
} else {
format!("request failed: {e}")
};
return (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": msg })),
)
.into_response();
}
};
if resp.status() == reqwest::StatusCode::UNAUTHORIZED {
return (
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({
"error": "API key required — set the provider env var and restart, \
or configure the key via Settings"
})),
)
.into_response();
}
if !resp.status().is_success() {
return (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
"error": format!("upstream returned {}", resp.status())
})),
)
.into_response();
}
let json: serde_json::Value = match resp.json().await {
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({ "error": format!("invalid JSON: {e}") })),
)
.into_response()
}
};
let model_ids: Vec<String> = json
.get("data")
.and_then(|d| d.as_array())
.map(|arr| {
arr.iter()
.filter_map(|m| m.get("id")?.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let count = model_ids.len();
let db_arc = shared.db.clone();
let pid = provider_id.clone();
let ids = model_ids.clone();
let _ = tokio::task::spawn_blocking(move || {
let mut conn = db_arc.lock().unwrap_or_else(|e| e.into_inner());
if let Err(e) = crate::admin::db::upsert_provider_models_cache(&mut conn, &pid, &ids) {
tracing::warn!(provider = %pid, error = %e, "failed to cache provider models");
}
})
.await;
tracing::info!(provider = %provider_id, count = count, "refreshed provider model cache");
(
StatusCode::OK,
Json(serde_json::json!({
"provider_id": provider_id,
"count": count,
"models": model_ids,
})),
)
.into_response()
}