use crate::auth::auth_headers;
use crate::provider_overrides::{override_headers, resolve_overrides, ProviderOverrides};
use crate::{
http_client::build_http_client,
profile::{ApiFamily, ProviderProfile, RuntimeConfig},
provider::{Provider, ProviderFuture},
InferenceRequest, ProviderError, ProviderEvent, ProviderResult,
};
use futures::stream::BoxStream;
use reqwest::header::HeaderMap;
use std::sync::Arc;
/// A ready-to-use connection to a single provider.
///
/// It bundles:
/// - the static provider profile;
/// - the base URL actually used for requests;
/// - the runtime overrides resolved against the profile;
/// - an HTTP client built from the composed request headers.
///
/// `Debug` is implemented by hand and deliberately leaves out `http_client`.
/// That client is built from the composed headers, which include the
/// credential headers. reqwest's `Client` `Debug` output can include its
/// default headers, so debug-printing a connection must not be a way to leak
/// credentials.
#[derive(Clone)]
pub struct ProviderConnection {
    /// Static provider metadata, shared so cloning a connection does not copy the profile.
    profile: Arc<ProviderProfile>,
    /// Base URL for requests: the runtime override when given, otherwise the profile's.
    effective_base_url: String,
    /// Runtime overrides resolved against the profile.
    overrides: ProviderOverrides,
    /// HTTP client built from the auth, override and profile-default headers.
    http_client: reqwest::Client,
}

impl std::fmt::Debug for ProviderConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `http_client` is intentionally skipped (it may carry credentials);
        // `finish_non_exhaustive` renders `..` to signal the omission.
        f.debug_struct("ProviderConnection")
            .field("profile", &self.profile)
            .field("effective_base_url", &self.effective_base_url)
            .field("overrides", &self.overrides)
            .finish_non_exhaustive()
    }
}
impl ProviderConnection {
    /// Builds a connection that takes ownership of `profile`.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the runtime configuration does not validate;
    /// - the profile has no auth strategy for the credential kind;
    /// - an OAuth bearer credential has already expired;
    /// - a header cannot be built;
    /// - the HTTP client cannot be constructed.
    pub fn from_profile(profile: ProviderProfile, runtime: RuntimeConfig) -> ProviderResult<Self> {
        let shared_profile = Arc::new(profile);
        Self::build(shared_profile, runtime)
    }

    /// Builds a connection from a profile that is already shared behind an [`Arc`].
    ///
    /// # Errors
    /// Same failure cases as [`ProviderConnection::from_profile`].
    pub(crate) fn from_arc(
        profile: Arc<ProviderProfile>,
        runtime: RuntimeConfig,
    ) -> ProviderResult<Self> {
        Self::build(profile, runtime)
    }

    /// Shared constructor behind [`ProviderConnection::from_profile`] and
    /// [`ProviderConnection::from_arc`].
    ///
    /// Steps, in order:
    /// 1. Validate the runtime config.
    /// 2. Pick the auth strategy for the credential kind.
    /// 3. Reject expired OAuth tokens.
    /// 4. Compose the request headers.
    /// 5. Build the HTTP client with the runtime's effective timeouts.
    /// 6. Resolve the base URL.
    fn build(profile: Arc<ProviderProfile>, runtime: RuntimeConfig) -> ProviderResult<Self> {
        runtime.validate()?;

        let context = format!("profile '{}'", profile.slug);
        let credential_kind = runtime.credential.kind();
        let strategy = match profile.auth_strategy_for(credential_kind) {
            Some(found) => found.clone(),
            None => {
                return Err(ProviderError::auth(format!(
                    "Provider '{}' does not support {:?} credentials",
                    profile.slug, credential_kind
                )));
            }
        };

        // Do not build a client around an OAuth token whose expiry time has passed.
        if let crate::profile::ProviderCredential::OAuthBearer {
            expires_at: Some(deadline),
            ..
        } = &runtime.credential
        {
            let now = std::time::SystemTime::now();
            if now >= *deadline {
                return Err(ProviderError::auth(format!(
                    "OAuth credential for '{}' has expired",
                    profile.slug
                )));
            }
        }

        let auth = auth_headers(&runtime.credential, &strategy, &context)?;
        let overrides = resolve_overrides(&profile, &runtime);
        let extra = override_headers(&overrides)?;
        let headers = compose_headers(auth, extra, &profile.default_headers, &context)?;
        let http_client = build_http_client(
            headers,
            runtime.effective_connect_timeout(),
            runtime.effective_read_timeout(),
        )?;

        // A runtime override beats the profile's static base URL; the profile
        // itself is never modified.
        let effective_base_url = match runtime.base_url_override {
            Some(url) => url,
            None => profile.base_url.clone(),
        };

        Ok(Self {
            profile,
            effective_base_url,
            overrides,
            http_client,
        })
    }

    /// Returns the static profile this connection was built from.
    pub fn profile(&self) -> &ProviderProfile {
        &*self.profile
    }

    /// Returns the base URL requests are sent to (runtime override or profile default).
    pub fn effective_base_url(&self) -> &str {
        self.effective_base_url.as_str()
    }
}
/// Merges the built-in, auth, override and profile-default header layers into
/// one header map for the HTTP client.
///
/// Layering, from weakest to strongest:
/// 1. `Content-Type: application/json` (built in);
/// 2. `auth_headers`, which replace any same-named header from layer 1;
/// 3. `override_headers`, which replace any same-named header from layers 1-2;
/// 4. `default_headers` from the profile, which may only add *new* names.
///
/// Every name set by layers 1-3 is protected. A profile default that reuses
/// one is rejected rather than silently overriding it.
///
/// Within layers 2 and 3, all values of a multi-valued header are kept.
///
/// Auth header values are flagged sensitive. `HeaderValue`'s `Debug` then
/// prints `Sensitive` instead of the credential, both for this map and for a
/// client holding it. HTTP/2 encoders may also avoid compressing them.
///
/// `context` is only used to label error messages.
///
/// # Errors
/// Returns an invalid-request [`ProviderError`] when a default header has an
/// invalid name or value, or collides with a protected header.
fn compose_headers(
    auth_headers: HeaderMap,
    override_headers: HeaderMap,
    default_headers: &std::collections::HashMap<String, String>,
    context: &str,
) -> ProviderResult<HeaderMap> {
    let mut final_headers = HeaderMap::new();
    final_headers.insert(
        reqwest::header::CONTENT_TYPE,
        reqwest::header::HeaderValue::from_static("application/json"),
    );
    // `HeaderName` is always stored lowercase, so names compare directly
    // without any manual case folding.
    let mut protected = std::collections::HashSet::new();
    protected.insert(reqwest::header::CONTENT_TYPE);

    merge_header_layer(&mut final_headers, &auth_headers, &mut protected, true);
    merge_header_layer(&mut final_headers, &override_headers, &mut protected, false);

    for (key, value) in default_headers {
        let hk = reqwest::header::HeaderName::from_bytes(key.as_bytes()).map_err(|e| {
            ProviderError::invalid_request(format!(
                "Invalid default header name '{}' for {}: {}",
                key, context, e
            ))
        })?;
        let hv = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
            ProviderError::invalid_request(format!(
                "Invalid default header value for '{}' on {}: {}",
                key, context, e
            ))
        })?;
        if protected.contains(&hk) {
            return Err(ProviderError::invalid_request(format!(
                "Profile default header '{}' collides with protected header for {}",
                key, context
            )));
        }
        final_headers.insert(hk, hv);
    }
    Ok(final_headers)
}

/// Copies every header of `layer` into `dst` and records each name as protected.
///
/// For each name in `layer`:
/// - values already in `dst` under that name are dropped;
/// - all of the layer's values for it are appended, so multi-valued headers
///   survive intact.
///
/// When `sensitive` is true, each copied value is flagged with
/// `HeaderValue::set_sensitive`.
fn merge_header_layer(
    dst: &mut HeaderMap,
    layer: &HeaderMap,
    protected: &mut std::collections::HashSet<reqwest::header::HeaderName>,
    sensitive: bool,
) {
    // `keys()` yields each name once, even when it has several values.
    for name in layer.keys() {
        protected.insert(name.clone());
        dst.remove(name);
        for value in layer.get_all(name).iter() {
            let mut value = value.clone();
            if sensitive {
                value.set_sensitive(true);
            }
            dst.append(name.clone(), value);
        }
    }
}
impl Provider for ProviderConnection {
fn infer(&self, request: InferenceRequest) -> ProviderFuture<'_, Vec<ProviderEvent>> {
let client = self.http_client.clone();
let profile = Arc::clone(&self.profile);
let overrides = self.overrides.clone();
let effective_base_url = self.effective_base_url.clone();
Box::pin(async move {
match profile.family {
ApiFamily::Messages => {
crate::apis::messages::infer(client, &profile, &effective_base_url, request)
.await
}
ApiFamily::Completions => {
crate::apis::completions::infer(client, &profile, &effective_base_url, request)
.await
}
ApiFamily::Responses => {
crate::apis::responses::infer(
client,
&profile,
&overrides,
&effective_base_url,
request,
)
.await
}
}
})
}
fn infer_stream(
&self,
request: InferenceRequest,
) -> ProviderFuture<'_, BoxStream<'static, ProviderResult<ProviderEvent>>> {
let client = self.http_client.clone();
let profile = Arc::clone(&self.profile);
let overrides = self.overrides.clone();
let effective_base_url = self.effective_base_url.clone();
Box::pin(async move {
match profile.family {
ApiFamily::Messages => {
crate::apis::messages::infer_stream(
client,
&profile,
&effective_base_url,
request,
)
.await
}
ApiFamily::Completions => {
crate::apis::completions::infer_stream(
client,
&profile,
&effective_base_url,
request,
)
.await
}
ApiFamily::Responses => {
crate::apis::responses::infer_stream(
client,
&profile,
&overrides,
&effective_base_url,
request,
)
.await
}
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a profile-default header table from `(name, value)` pairs.
    fn defaults(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    /// Asserts that composition was rejected because a default reused a protected name.
    fn assert_protected_collision(result: ProviderResult<HeaderMap>) {
        let err = result.expect_err("expected a protected-header collision error");
        assert!(err.to_string().contains("collides with protected header"));
    }

    /// The plain Completions profile shared by the base-URL tests.
    fn example_profile() -> ProviderProfile {
        ProviderProfile::new("test", ApiFamily::Completions, "https://api.example.com/v1")
    }

    #[test]
    fn test_protected_header_collision_fails() {
        let mut auth = HeaderMap::new();
        auth.insert(
            reqwest::header::AUTHORIZATION,
            reqwest::header::HeaderValue::from_static("Bearer secret"),
        );
        let table = defaults(&[("authorization", "new")]);
        assert_protected_collision(compose_headers(auth, HeaderMap::new(), &table, "test"));
    }

    #[test]
    fn test_content_type_collision_fails() {
        let table = defaults(&[("content-type", "text/plain")]);
        assert_protected_collision(compose_headers(
            HeaderMap::new(),
            HeaderMap::new(),
            &table,
            "test",
        ));
    }

    #[test]
    fn test_non_protected_header_allowed() {
        let table = defaults(&[("x-custom", "value")]);
        let headers = compose_headers(HeaderMap::new(), HeaderMap::new(), &table, "test")
            .expect("a non-protected default header must be accepted");
        assert_eq!(headers.get("x-custom").unwrap(), "value");
    }

    #[test]
    fn test_override_header_is_protected() {
        let mut overrides = HeaderMap::new();
        overrides.insert(
            "anthropic-version",
            reqwest::header::HeaderValue::from_static("2023-06-01"),
        );
        let table = defaults(&[("anthropic-version", "2024-01")]);
        assert_protected_collision(compose_headers(HeaderMap::new(), overrides, &table, "test"));
    }

    #[test]
    fn test_local_connection_with_noauth() {
        let registry = crate::registry::ProviderRegistry::default();
        assert!(registry.get("local", RuntimeConfig::none()).is_ok());
    }

    #[test]
    fn test_local_connection_with_api_key() {
        let registry = crate::registry::ProviderRegistry::default();
        assert!(registry.get("local", RuntimeConfig::new("my-token")).is_ok());
    }

    #[test]
    fn test_noauth_fails_for_api_key_provider() {
        let registry = crate::registry::ProviderRegistry::default();
        assert!(registry.get("zai", RuntimeConfig::none()).is_err());
    }

    #[test]
    fn test_noauth_empty_auth_headers_composition() {
        let headers =
            compose_headers(HeaderMap::new(), HeaderMap::new(), &HashMap::new(), "test").unwrap();
        for absent in ["authorization", "x-api-key"] {
            assert!(headers.get(absent).is_none());
        }
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
    }

    #[test]
    fn test_effective_base_url_uses_profile_default() {
        let conn =
            ProviderConnection::from_profile(example_profile(), RuntimeConfig::new("test-key"))
                .unwrap();
        assert_eq!(conn.effective_base_url(), "https://api.example.com/v1");
        assert_eq!(conn.profile().base_url, "https://api.example.com/v1");
    }

    #[test]
    fn test_effective_base_url_uses_override() {
        let runtime = RuntimeConfig::new("test-key").with_base_url("http://localhost:1234/v1");
        let conn = ProviderConnection::from_profile(example_profile(), runtime).unwrap();
        assert_eq!(conn.effective_base_url(), "http://localhost:1234/v1");
        assert_eq!(
            conn.profile().base_url,
            "https://api.example.com/v1",
            "static profile metadata must remain unchanged"
        );
    }

    #[test]
    fn test_effective_base_url_preserves_profile_metadata() {
        let profile = example_profile().with_credential_auth(
            crate::profile::CredentialKind::NoAuth,
            crate::profile::AuthStrategy::NoAuth,
        );
        let runtime = RuntimeConfig::none().with_base_url("http://custom:9999");
        let conn = ProviderConnection::from_profile(profile.clone(), runtime).unwrap();

        assert_eq!(conn.effective_base_url(), "http://custom:9999");
        assert_eq!(conn.profile().family, ApiFamily::Completions);
        assert_eq!(conn.profile().slug, "test");
        assert!(conn
            .profile()
            .supports_credential(crate::profile::CredentialKind::NoAuth));
        // The caller's copy of the profile is untouched by the override.
        assert_eq!(profile.base_url, "https://api.example.com/v1");
    }
}