use axum::http::HeaderMap;
use llm_providers::get_endpoint;
use serde_json::Value;
use unigateway_runtime::host::ResolvedProvider;
use crate::config::{GatewayState, ServiceProvider};
pub fn normalize_base_url(url: &str) -> String {
let mut s = url.trim().to_string();
if s.is_empty() {
return s;
}
if !s.ends_with('/') {
s.push('/');
}
s
}
fn resolve_service_provider(sp: &ServiceProvider) -> Option<ResolvedProvider> {
let (base_url, family_id) = resolve_upstream(sp.base_url.clone(), sp.endpoint_id.as_deref())?;
let base_url = normalize_base_url(&base_url);
let api_key = sp.api_key.clone()?;
if api_key.is_empty() {
return None;
}
Some(ResolvedProvider {
name: sp.name.clone(),
provider_type: sp.provider_type.clone(),
endpoint_id: sp.endpoint_id.clone(),
base_url,
api_key,
family_id,
default_model: sp.default_model.clone(),
model_mapping: sp.model_mapping.clone(),
})
}
pub fn resolve_upstream(
provider_base_url: Option<String>,
endpoint_id: Option<&str>,
) -> Option<(String, Option<String>)> {
if let Some(eid) = endpoint_id {
let eid = eid.trim();
if !eid.is_empty() {
if let Some((family_id, endpoint)) = get_endpoint(eid) {
return Some((
normalize_base_url(endpoint.base_url),
Some(family_id.to_string()),
));
}
tracing::debug!(
"get_endpoint({:?}) returned None, falling back to provider base_url",
eid
);
}
}
let url = provider_base_url.as_deref()?.trim();
if url.is_empty() {
return None;
}
Some((normalize_base_url(url), None))
}
pub fn target_provider_hint(headers: &HeaderMap, payload: &Value) -> Option<String> {
let from_header = headers
.get("x-unigateway-provider")
.or_else(|| headers.get("x-target-vendor"))
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToString::to_string);
if from_header.is_some() {
return from_header;
}
payload
.get("target_vendor")
.or_else(|| payload.get("target_provider"))
.or_else(|| payload.get("provider"))
.and_then(|v| v.as_str())
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToString::to_string)
}
pub async fn resolve_providers(
gateway: &GatewayState,
service_id: &str,
protocol_hint: &str, target_hint: Option<&str>,
) -> Result<Vec<ResolvedProvider>, String> {
if let Some(hint) = target_hint {
let sp = if let Some(sp) = gateway
.select_provider_for_service_with_hint(service_id, "", hint)
.await
{
sp
} else {
gateway
.select_provider_for_service_with_hint(service_id, protocol_hint, hint)
.await
.ok_or_else(|| format!("no provider matches target '{hint}'"))?
};
let rp = resolve_service_provider(&sp)
.ok_or_else(|| format!("provider '{}': base_url or api_key missing", sp.name))?;
return Ok(vec![rp]);
}
let strategy = gateway.get_routing_strategy(service_id).await;
if strategy == "fallback" {
let mut all = gateway
.select_all_providers_for_service(service_id, "")
.await;
if all.is_empty() {
all = gateway
.select_all_providers_for_service(service_id, protocol_hint)
.await;
}
if all.is_empty() {
return Err(format!("no provider bound for service/{protocol_hint}"));
}
let resolved: Vec<ResolvedProvider> =
all.iter().filter_map(resolve_service_provider).collect();
if resolved.is_empty() {
return Err("all bound providers have missing base_url or api_key".to_string());
}
return Ok(resolved);
}
let mut sp = gateway.select_provider_for_service(service_id, "").await;
if sp.is_none() {
sp = gateway
.select_provider_for_service(service_id, protocol_hint)
.await;
}
let sp = sp.ok_or_else(|| format!("no provider bound for service/{protocol_hint}"))?;
let rp = resolve_service_provider(&sp)
.ok_or_else(|| format!("provider '{}': base_url or api_key missing", sp.name))?;
Ok(vec![rp])
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::{resolve_providers, resolve_upstream};
use crate::config::GatewayState;
use tempfile::tempdir;
#[test]
fn resolve_upstream_minimax_global() {
let r = resolve_upstream(None, Some("minimax:global"));
let (url, family) = r.expect("get_endpoint(minimax:global) should return Some");
assert!(
url.contains("minimax"),
"base_url should contain minimax: {}",
url
);
assert_eq!(family.as_deref(), Some("minimax"));
}
#[tokio::test]
async fn resolve_providers_allows_cross_protocol_target_hint() {
let dir = tempdir().expect("tempdir");
let config_path = dir.path().join("config.toml");
let state: Arc<GatewayState> = GatewayState::load(&config_path).await.expect("load state");
state.create_service("svc", "Service").await;
let provider_id = state
.create_provider(
"moonshot",
"openai",
"moonshot:global",
None,
"sk-test-moonshot",
None,
)
.await;
state
.bind_provider_to_service("svc", provider_id)
.await
.expect("bind provider");
let providers = resolve_providers(&state, "svc", "anthropic", Some("moonshot"))
.await
.expect("resolve providers");
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].name, "moonshot");
assert_eq!(providers[0].provider_type, "openai");
}
#[tokio::test]
async fn resolve_providers_allows_cross_protocol_round_robin() {
let dir = tempdir().expect("tempdir");
let config_path = dir.path().join("config.toml");
let state: Arc<GatewayState> = GatewayState::load(&config_path).await.expect("load state");
state.create_service("svc", "Service").await;
let provider_id = state
.create_provider(
"moonshot",
"openai",
"moonshot:global",
None,
"sk-test-moonshot",
None,
)
.await;
state
.bind_provider_to_service("svc", provider_id)
.await
.expect("bind provider");
let providers = resolve_providers(&state, "svc", "anthropic", None)
.await
.expect("resolve providers");
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].name, "moonshot");
assert_eq!(providers[0].provider_type, "openai");
}
}