use axum::{
body::Body,
http::{header, HeaderMap, HeaderValue, Method, StatusCode},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use futures::StreamExt;
use serde_json::Value;
use tokio::time::Duration;
use crate::proxy::state::AppState;
fn map_model_for_zai(original: &str, state: &crate::proxy::ZaiConfig) -> String {
let m = original.to_lowercase();
if let Some(mapped) = state.model_mapping.get(original) {
return mapped.clone();
}
if let Some(mapped) = state.model_mapping.get(&m) {
return mapped.clone();
}
if m.starts_with("zai:") {
return original[4..].to_string();
}
if m.starts_with("glm-") {
return original.to_string();
}
if !crate::proxy::common::model_mapping::is_claude_model(&m) {
return original.to_string();
}
if m.contains("opus") {
return state.models.opus.clone();
}
if m.contains("haiku") {
return state.models.haiku.clone();
}
state.models.sonnet.clone()
}
fn join_base_url(base: &str, path: &str) -> Result<String, String> {
let base = base.trim_end_matches('/');
let path = if path.starts_with('/') {
path.to_string()
} else {
format!("/{}", path)
};
Ok(format!("{}{}", base, path))
}
fn build_client(
upstream_proxy: Option<crate::proxy::config::UpstreamProxyConfig>,
timeout_secs: u64,
) -> Result<reqwest::Client, String> {
let mut builder = crate::utils::http::apply_tls_backend(reqwest::Client::builder())
.timeout(Duration::from_secs(timeout_secs.max(5)));
if let Some(config) = upstream_proxy {
if config.enabled && !config.url.is_empty() {
let proxy = reqwest::Proxy::all(&config.url)
.map_err(|e| format!("Invalid upstream proxy url: {}", e))?;
builder = builder.proxy(proxy);
}
}
builder
.tcp_nodelay(true)
.build()
.map_err(|e| format!("Failed to build HTTP client: {}", e))
}
fn copy_passthrough_headers(incoming: &HeaderMap) -> HeaderMap {
let mut out = HeaderMap::new();
for (k, v) in incoming.iter() {
let key = k.as_str().to_ascii_lowercase();
match key.as_str() {
"content-type" | "accept" | "anthropic-version" | "user-agent" => {
out.insert(k.clone(), v.clone());
}
"accept-encoding" | "cache-control" => {
out.insert(k.clone(), v.clone());
}
_ => {}
}
}
out
}
fn set_zai_auth(headers: &mut HeaderMap, incoming: &HeaderMap, api_key: &str) {
let has_x_api_key = incoming.contains_key("x-api-key");
let has_auth = incoming.contains_key(header::AUTHORIZATION);
if has_x_api_key || !has_auth {
if let Ok(v) = HeaderValue::from_str(api_key) {
headers.insert("x-api-key", v);
}
}
if has_auth {
if let Ok(v) = HeaderValue::from_str(&format!("Bearer {}", api_key)) {
headers.insert(header::AUTHORIZATION, v);
}
}
}
pub fn deep_remove_cache_control(value: &mut Value) {
match value {
Value::Object(map) => {
if let Some(v) = map.remove("cache_control") {
tracing::info!(" Deep Cleaning found nested cache_control: {:?}", v);
}
for v in map.values_mut() {
deep_remove_cache_control(v);
}
}
Value::Array(arr) => {
for v in arr {
deep_remove_cache_control(v);
}
}
_ => {}
}
}
pub async fn forward_anthropic_json(
state: &AppState,
method: Method,
path: &str,
incoming_headers: &HeaderMap,
mut body: Value,
message_count: usize,
) -> Response {
let zai = state.config.zai.read().await.clone();
if !zai.enabled || zai.dispatch_mode == crate::proxy::ZaiDispatchMode::Off {
return (StatusCode::BAD_REQUEST, "z.ai is disabled").into_response();
}
if zai.api_key.trim().is_empty() {
return (StatusCode::BAD_REQUEST, "z.ai api_key is not set").into_response();
}
if let Some(model) = body.get("model").and_then(|v| v.as_str()) {
let mapped = map_model_for_zai(model, &zai);
body["model"] = Value::String(mapped.clone());
if let Some(sig) = body
.get("thinking")
.and_then(|t| t.get("signature"))
.and_then(|s| s.as_str())
{
crate::proxy::SignatureCache::global().cache_session_signature(
"zai-session",
sig.to_string(),
message_count,
);
crate::proxy::SignatureCache::global().cache_thinking_family(sig.to_string(), mapped);
}
}
let url = match join_base_url(&zai.base_url, path) {
Ok(u) => u,
Err(e) => return (StatusCode::BAD_REQUEST, e).into_response(),
};
let timeout_secs = state.config.request_timeout_secs().max(5);
let upstream_proxy = state.config.upstream_proxy.read().await.clone();
let client = match build_client(Some(upstream_proxy), timeout_secs) {
Ok(c) => c,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(),
};
let mut headers = copy_passthrough_headers(incoming_headers);
set_zai_auth(&mut headers, incoming_headers, &zai.api_key);
headers
.entry(header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/json"));
if let Some(cc) = body.get("cache_control") {
tracing::info!(" Deep cleaning cache_control from ROOT: {:?}", cc);
}
deep_remove_cache_control(&mut body);
let body_bytes = serde_json::to_vec(&body).unwrap_or_default();
let body_len = body_bytes.len();
tracing::debug!(
"Forwarding request to z.ai (len: {} bytes): {}",
body_len,
url
);
let req = client
.request(method, &url)
.headers(headers)
.body(body_bytes);
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
return (
StatusCode::BAD_GATEWAY,
format!("Upstream request failed: {}", e),
)
.into_response();
}
};
let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let mut out = Response::builder().status(status);
if let Some(ct) = resp.headers().get(header::CONTENT_TYPE) {
out = out.header(header::CONTENT_TYPE, ct.clone());
}
let stream = resp.bytes_stream().map(|chunk| match chunk {
Ok(b) => Ok::<Bytes, std::io::Error>(b),
Err(e) => Ok(Bytes::from(format!("Upstream stream error: {}", e))),
});
out.body(Body::from_stream(stream)).unwrap_or_else(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to build response",
)
.into_response()
})
}