use crate::registry::{ApiType, ProviderSpec};
use agent_diva_core::config::ProviderConfig;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::time::Duration;
/// Origin of the model list carried by a `ProviderModelCatalog`.
///
/// Variants are serialized with snake_case names (for example `static_fallback`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ModelCatalogSource {
/// The models were returned by the provider's runtime discovery request.
Runtime,
/// The models were copied from the provider spec's static `models` list.
StaticFallback,
/// No runtime discovery strategy applied and the static list was not used.
Unsupported,
/// Runtime discovery was attempted, failed, and the static list was not used.
Error,
}
/// Outcome of a model discovery attempt for a single provider.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderModelCatalog {
/// Provider name; the builders in this file copy it from the spec's `name`.
pub provider: String,
/// Where the entries in `models` came from.
pub source: ModelCatalogSource,
/// Whether a runtime discovery strategy was selected for the provider.
pub runtime_supported: bool,
/// API base resolved for the provider (configured value or spec default), if any.
pub api_base: Option<String>,
/// Model identifiers; the builders in this file leave it empty for
/// `Unsupported` and `Error` results.
pub models: Vec<String>,
/// Non-fatal problems, such as the runtime failure that led to a static fallback.
/// Treated as empty when the field is missing from deserialized input.
#[serde(default)]
pub warnings: Vec<String>,
/// Failure description; the builders in this file set it for `Unsupported`
/// and `Error` results only.
pub error: Option<String>,
}
/// Connection settings used when querying a provider for its models.
#[derive(Debug, Clone)]
pub struct ProviderAccess {
/// API key sent as a bearer token when present and not blank.
pub api_key: Option<String>,
/// User-configured API base; takes precedence over the spec's default base.
pub api_base: Option<String>,
/// Additional `(name, value)` headers attached to the discovery request.
pub extra_headers: Vec<(String, String)>,
}
impl ProviderAccess {
    /// Derives access settings from an optional provider configuration.
    ///
    /// * `api_key` and `api_base` are trimmed; values that are empty after
    ///   trimming become `None`.
    /// * `extra_headers` is copied pair by pair, or left empty when the
    ///   configuration has none.
    ///
    /// A missing configuration yields an access value with no key, no base
    /// and no extra headers.
    pub fn from_config(config: Option<&ProviderConfig>) -> Self {
        let cfg = match config {
            Some(cfg) => cfg,
            None => {
                return Self {
                    api_key: None,
                    api_base: None,
                    extra_headers: Vec::new(),
                };
            }
        };

        // Trim on the borrowed text first so blank values never allocate.
        let api_key = Some(cfg.api_key.trim())
            .filter(|key| !key.is_empty())
            .map(str::to_string);

        let api_base = cfg
            .api_base
            .as_deref()
            .map(str::trim)
            .filter(|base| !base.is_empty())
            .map(str::to_string);

        let extra_headers: Vec<(String, String)> = match cfg.extra_headers.as_ref() {
            Some(headers) => headers
                .iter()
                .map(|(name, value)| (name.clone(), value.clone()))
                .collect(),
            None => Vec::new(),
        };

        Self {
            api_key,
            api_base,
            extra_headers,
        }
    }
}
/// Builds the model catalog for `spec` using the connection settings in `access`.
///
/// The effective API base is resolved first (configured base, else the spec's
/// default). What happens next depends on the selected discovery strategy:
///
/// * OpenAI-compatible: the models are fetched at runtime. On success the
///   catalog source is `Runtime`; on failure the result is delegated to
///   `fallback_or_error`, which either serves the spec's static models with
///   the failure as a warning or reports an `Error` catalog.
/// * No strategy: the spec's static models are returned as `StaticFallback`
///   when `allow_static_fallback` is set and the list is non-empty; otherwise
///   an `Unsupported` catalog carrying an explanatory error is returned.
///
/// This function never fails; every problem is encoded in the returned catalog.
pub async fn fetch_provider_model_catalog(
    spec: &ProviderSpec,
    access: &ProviderAccess,
    allow_static_fallback: bool,
) -> ProviderModelCatalog {
    let api_base = effective_api_base(spec, access);
    let strategy = runtime_strategy(spec, &api_base);

    match strategy {
        None => {
            let serve_static = allow_static_fallback && !spec.models.is_empty();
            let (source, models, error) = if serve_static {
                (
                    ModelCatalogSource::StaticFallback,
                    spec.models.clone(),
                    None,
                )
            } else {
                // The explanation is only needed when nothing can be served.
                let message = format!(
                    "Provider '{}' does not support runtime model discovery in this build",
                    provider_identity(spec)
                );
                (ModelCatalogSource::Unsupported, Vec::new(), Some(message))
            };
            ProviderModelCatalog {
                provider: spec.name.clone(),
                source,
                runtime_supported: false,
                api_base,
                models,
                warnings: Vec::new(),
                error,
            }
        }
        Some(RuntimeDiscoveryStrategy::OpenAiCompatible) => {
            let fetched = fetch_openai_compatible_models(spec, access, api_base.as_deref()).await;
            match fetched {
                Err(error) => fallback_or_error(spec, api_base, allow_static_fallback, error, true),
                Ok(models) => ProviderModelCatalog {
                    provider: spec.name.clone(),
                    source: ModelCatalogSource::Runtime,
                    runtime_supported: true,
                    api_base,
                    models,
                    warnings: Vec::new(),
                    error: None,
                },
            }
        }
    }
}
fn fallback_or_error(
spec: &ProviderSpec,
api_base: Option<String>,
allow_static_fallback: bool,
error: String,
runtime_supported: bool,
) -> ProviderModelCatalog {
if allow_static_fallback && !spec.models.is_empty() {
ProviderModelCatalog {
provider: spec.name.clone(),
source: ModelCatalogSource::StaticFallback,
runtime_supported,
api_base,
models: spec.models.clone(),
warnings: vec![error],
error: None,
}
} else {
ProviderModelCatalog {
provider: spec.name.clone(),
source: ModelCatalogSource::Error,
runtime_supported,
api_base,
models: vec![],
warnings: vec![],
error: Some(error),
}
}
}
/// How the model list is obtained from a provider at runtime.
#[derive(Debug, Clone, Copy)]
enum RuntimeDiscoveryStrategy {
/// Issue `GET {api_base}/models` and read the ids from the `data` array.
OpenAiCompatible,
}
/// Selects the runtime discovery strategy for a provider, if one applies.
///
/// OpenAI-compatible discovery is chosen only when the spec's API type is
/// `ApiType::Openai` and an API base is available; every other combination
/// yields `None`.
fn runtime_strategy(
    spec: &ProviderSpec,
    api_base: &Option<String>,
) -> Option<RuntimeDiscoveryStrategy> {
    match (&spec.api_type, api_base) {
        (ApiType::Openai, Some(_)) => Some(RuntimeDiscoveryStrategy::OpenAiCompatible),
        _ => None,
    }
}
/// Resolves the API base used for runtime discovery.
///
/// The configured base in `access` wins; the spec's `default_api_base` is the
/// fallback. Both candidates go through the same normalization: surrounding
/// whitespace and trailing `/` characters are removed, and a candidate that
/// is empty afterwards is treated as absent.
///
/// Returns `None` when neither candidate yields a non-empty base.
fn effective_api_base(spec: &ProviderSpec, access: &ProviderAccess) -> Option<String> {
    /// Trims whitespace and trailing slashes; `None` if nothing is left.
    fn normalize(raw: &str) -> Option<String> {
        let cleaned = raw.trim().trim_end_matches('/');
        (!cleaned.is_empty()).then(|| cleaned.to_string())
    }

    access
        .api_base
        .as_deref()
        .and_then(normalize)
        // The default is checked *after* stripping slashes too, so a default
        // such as "/" is reported as missing instead of as an empty base that
        // would produce the request URL "/models".
        .or_else(|| normalize(&spec.default_api_base))
}
/// Lists a provider's models through the OpenAI-compatible `GET {api_base}/models` endpoint.
///
/// * `spec` is used only to name the provider in error messages.
/// * `access` supplies the optional bearer token and any extra request headers.
/// * `api_base` is the already-normalized base URL (no trailing slash).
///
/// On success the returned ids are trimmed, blank ids are dropped, and the
/// remainder is de-duplicated and sorted.
///
/// # Errors
///
/// Returns a human-readable message when `api_base` is missing, the HTTP
/// client cannot be built, the request fails, the server answers with a
/// non-success status, or the body cannot be decoded as the expected JSON.
async fn fetch_openai_compatible_models(
    spec: &ProviderSpec,
    access: &ProviderAccess,
    api_base: Option<&str>,
) -> Result<Vec<String>, String> {
    let api_base = api_base.ok_or_else(|| {
        format!(
            "Provider '{}' has no configured or default api_base for runtime discovery",
            provider_identity(spec)
        )
    })?;
    let client = crate::http_util::build_api_http_client(api_base, Duration::from_secs(15))
        .map_err(|error| format!("failed to create HTTP client: {error}"))?;
    let url = format!("{api_base}/models");
    let mut request = client.get(&url);

    // Send the trimmed key: the emptiness check already ignores surrounding
    // whitespace, and a padded key (e.g. one read from a file with a trailing
    // newline) would otherwise produce a wrong or invalid Authorization header.
    if let Some(api_key) = access
        .api_key
        .as_deref()
        .map(str::trim)
        .filter(|key| !key.is_empty())
    {
        request = request.bearer_auth(api_key);
    }
    for (key, value) in &access.extra_headers {
        request = request.header(key, value);
    }

    let response = request.send().await.map_err(|error| {
        format!(
            "Runtime model discovery request failed for provider '{}' at '{}': {}",
            provider_identity(spec),
            url,
            error
        )
    })?;

    let status = response.status();
    if !status.is_success() {
        // The body is best-effort context only; an unreadable body is
        // reported the same way as an empty one.
        let detail = response.text().await.unwrap_or_default();
        let summary = http_error_summary(status, &detail);
        return Err(format!(
            "Runtime model discovery request returned {} for provider '{}' at '{}': {}",
            status,
            provider_identity(spec),
            url,
            summary
        ));
    }

    let payload: OpenAiModelsResponse = response.json().await.map_err(|error| {
        format!(
            "Runtime model discovery response was invalid JSON for provider '{}' at '{}': {}",
            provider_identity(spec),
            url,
            error
        )
    })?;

    // A BTreeSet removes duplicates and yields the ids in sorted order.
    let unique_models: BTreeSet<String> = payload
        .data
        .iter()
        .map(|model| model.id.trim())
        .filter(|id| !id.is_empty())
        .map(str::to_string)
        .collect();
    Ok(unique_models.into_iter().collect())
}
/// Condenses an HTTP error body into a short, single-line summary.
///
/// A body that is blank after trimming is reported as `empty response body`.
/// Otherwise newlines are replaced by spaces and the text is cut to at most
/// 200 characters, prefixed with the status; `...` is appended when
/// characters were cut off.
fn http_error_summary(status: StatusCode, detail: &str) -> String {
    const PREVIEW_CHARS: usize = 200;

    let body = detail.trim();
    if body.is_empty() {
        return String::from("empty response body");
    }

    let single_line = body.replace('\n', " ");
    let mut remaining = single_line.chars();
    // `by_ref` keeps the iterator usable so the leftover can be probed below.
    let preview: String = remaining.by_ref().take(PREVIEW_CHARS).collect();

    if remaining.next().is_none() {
        format!("{status}: {preview}")
    } else {
        format!("{status}: {preview}...")
    }
}
/// Picks the label used for a provider in diagnostic messages.
///
/// The spec's `name` is returned as-is when it is not blank. Otherwise the
/// trimmed `display_name` is used, and `<unknown>` when that is blank too.
fn provider_identity(spec: &ProviderSpec) -> &str {
    if !spec.name.trim().is_empty() {
        return spec.name.as_str();
    }
    match spec.display_name.trim() {
        "" => "<unknown>",
        display_name => display_name,
    }
}
/// Body of an OpenAI-compatible `/models` response; only `data` is read.
#[derive(Debug, Deserialize)]
struct OpenAiModelsResponse {
/// Listed models; treated as empty when the field is absent.
#[serde(default)]
data: Vec<OpenAiModelEntry>,
}
/// One element of the `data` array in a `/models` response.
#[derive(Debug, Deserialize)]
struct OpenAiModelEntry {
/// Model identifier; required, so an entry without it fails deserialization.
id: String,
}
// Tests for the catalog discovery flow. HTTP interactions are served by a
// local mockito server whose URL is used as the provider's API base.
#[cfg(test)]
mod tests {
use super::*;
use crate::{registry::ApiType, ProviderSpec};
use mockito::{Matcher, Server};
use std::collections::HashMap;
// Builds an OpenAI-type spec named `name` whose default API base is
// `api_base` and which carries two static models ("fallback-a", "fallback-b").
fn openai_like_spec(name: &str, api_base: &str) -> ProviderSpec {
ProviderSpec {
name: name.to_string(),
api_type: ApiType::Openai,
keywords: vec![name.to_string()],
env_key: "TEST_API_KEY".to_string(),
display_name: name.to_string(),
default_model: None,
litellm_prefix: String::new(),
skip_prefixes: vec![],
env_extras: vec![],
default_api_base: api_base.to_string(),
supports_prompt_caching: false,
models: vec!["fallback-a".to_string(), "fallback-b".to_string()],
model_overrides: vec![],
}
}
// Happy path: the bearer token and extra header reach the server, and the
// duplicated "gpt-4o" id in the response is collapsed in the sorted result.
#[tokio::test]
async fn fetch_provider_model_catalog_returns_runtime_models_for_openai_compatible() {
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/models")
.match_header("authorization", "Bearer sk-test")
.match_header("x-test", "1")
.with_status(200)
.with_body(r#"{"data":[{"id":"gpt-4o"},{"id":"gpt-4o-mini"},{"id":"gpt-4o"}]}"#)
.create_async()
.await;
let spec = openai_like_spec("openai", &server.url());
let access = ProviderAccess {
api_key: Some("sk-test".to_string()),
api_base: None,
extra_headers: vec![("x-test".to_string(), "1".to_string())],
};
let catalog = fetch_provider_model_catalog(&spec, &access, false).await;
mock.assert_async().await;
assert_eq!(catalog.source, ModelCatalogSource::Runtime);
assert_eq!(
catalog.models,
vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()]
);
assert_eq!(catalog.error, None);
}
// A 401 from the server with fallback enabled yields the spec's static
// models, with the HTTP failure surfaced as a warning rather than an error.
#[tokio::test]
async fn fetch_provider_model_catalog_uses_static_fallback_on_http_error() {
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/models")
.with_status(401)
.with_body(r#"{"error":"unauthorized"}"#)
.create_async()
.await;
let spec = openai_like_spec("openai", &server.url());
let catalog = fetch_provider_model_catalog(
&spec,
&ProviderAccess {
api_key: Some("sk-bad".to_string()),
api_base: None,
extra_headers: vec![],
},
true,
)
.await;
mock.assert_async().await;
assert_eq!(catalog.source, ModelCatalogSource::StaticFallback);
assert_eq!(catalog.models, spec.models);
assert!(catalog.warnings[0].contains("401 Unauthorized"));
}
// A non-OpenAI API type with no static models and fallback disabled is
// reported as unsupported; no HTTP server is involved.
#[tokio::test]
async fn fetch_provider_model_catalog_returns_unsupported_without_fallback() {
let spec = ProviderSpec {
name: "anthropic".to_string(),
api_type: ApiType::Anthropic,
keywords: vec!["claude".to_string()],
env_key: "ANTHROPIC_API_KEY".to_string(),
display_name: "Anthropic".to_string(),
default_model: None,
litellm_prefix: String::new(),
skip_prefixes: vec![],
env_extras: vec![],
default_api_base: String::new(),
supports_prompt_caching: false,
models: vec![],
model_overrides: vec![],
};
let catalog =
fetch_provider_model_catalog(&spec, &ProviderAccess::from_config(None), false).await;
assert_eq!(catalog.source, ModelCatalogSource::Unsupported);
assert!(catalog.error.unwrap().contains("does not support"));
}
// Runtime discovery works for any OpenAI-type provider with an API base,
// here without an API key.
#[tokio::test]
async fn fetch_provider_model_catalog_supports_generic_openai_type_with_api_base() {
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/models")
.with_status(200)
.with_body(r#"{"data":[{"id":"grok-2-latest"}]}"#)
.create_async()
.await;
let spec = openai_like_spec("xai", &server.url());
let catalog = fetch_provider_model_catalog(
&spec,
&ProviderAccess {
api_key: None,
api_base: Some(server.url()),
extra_headers: vec![],
},
false,
)
.await;
mock.assert_async().await;
assert_eq!(catalog.source, ModelCatalogSource::Runtime);
assert_eq!(catalog.models, vec!["grok-2-latest".to_string()]);
}
// `from_config` drops a whitespace-only key and trims the base; the trailing
// slash is kept at this stage (it is stripped later, when the effective base
// is resolved).
#[test]
fn provider_access_from_config_filters_empty_values() {
let access = ProviderAccess::from_config(Some(&ProviderConfig {
api_key: " ".to_string(),
api_base: Some(" https://example.com/v1/ ".to_string()),
extra_headers: Some(HashMap::from([
("x-a".to_string(), "1".to_string()),
("x-b".to_string(), "2".to_string()),
])),
custom_models: vec![],
}));
assert_eq!(access.api_key, None);
assert_eq!(access.api_base, Some("https://example.com/v1/".to_string()));
assert_eq!(access.extra_headers.len(), 2);
}
// The configured base must override the spec default: the default points at
// an unused host, so the mock is only hit if the configured base is used.
#[tokio::test]
async fn fetch_provider_model_catalog_prefers_configured_api_base() {
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/models")
.match_query(Matcher::Missing)
.with_status(200)
.with_body(r#"{"data":[{"id":"model-a"}]}"#)
.create_async()
.await;
let spec = openai_like_spec("custom", "http://unused.example/v1");
let access = ProviderAccess {
api_key: None,
api_base: Some(server.url()),
extra_headers: vec![],
};
let catalog = fetch_provider_model_catalog(&spec, &access, false).await;
mock.assert_async().await;
assert_eq!(catalog.source, ModelCatalogSource::Runtime);
assert_eq!(catalog.api_base, Some(server.url()));
}
}