use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use futures_util::StreamExt;
use log_lazy::LogLazy;
use reqwest::Client;
use crate::oauth::OAuthProvider;
use crate::token::TokenManager;
/// Shared state handed to every handler through axum's `State` extractor.
///
/// `State<AppState>` extraction requires the state to be `Clone`.
#[derive(Clone)]
pub struct AppState {
    /// HTTP client used to send forwarded requests upstream.
    pub client: Client,
    /// Issues and validates the client-facing proxy tokens.
    pub token_manager: TokenManager,
    /// Supplies the OAuth token attached to upstream requests.
    pub oauth_provider: OAuthProvider,
    /// Base URL that request paths are appended to; a trailing `/` is trimmed before use.
    pub upstream_base_url: String,
    /// Logger whose messages are built lazily via closures.
    pub logger: LogLazy,
}
/// Path prefix under which the proxied Anthropic API routes are exposed to clients.
pub const API_PREFIX: &str = "/api/latest/anthropic/";

/// Client headers expected to pass through to the upstream API by name.
///
/// Their forwarded values are logged at trace level when upstream headers are built.
pub const REQUIRED_FORWARD_HEADERS: &[&str] =
    &["anthropic-beta", "anthropic-version", "x-claude-code-session-id"];
/// Headers that are never copied between client and upstream, in either direction.
///
/// This covers the hop-by-hop headers of RFC 7230 §6.1 (plus the de-facto
/// `proxy-connection`). It also includes `host`, which the HTTP client sets itself for
/// the upstream URL. Entries must be lowercase, because `HeaderName::as_str` always
/// yields lowercase names and callers compare against it directly.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "host",
    "connection",
    "transfer-encoding",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "upgrade",
];
/// Liveness probe: always answers `200 OK` with the plain-text body `ok`.
// Axum handlers must be `async` even when they never await, hence the allow.
#[allow(clippy::unused_async)]
pub async fn health() -> impl IntoResponse {
    let body = "ok";
    (StatusCode::OK, body)
}
#[allow(clippy::unused_async)]
pub async fn issue_token(
State(state): State<AppState>,
axum::Json(req): axum::Json<IssueTokenRequest>,
) -> impl IntoResponse {
let ttl = req.ttl_hours.unwrap_or(24);
let label = req.label.unwrap_or_default();
match state.token_manager.issue_token(ttl, &label) {
Ok(token) => (
StatusCode::OK,
axum::Json(serde_json::json!({
"token": token,
"ttl_hours": ttl,
"label": label,
})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(serde_json::json!({
"error": format!("Failed to issue token: {e}")
})),
)
.into_response(),
}
}
/// JSON request body accepted by [`issue_token`]; both fields are optional.
#[derive(serde::Deserialize)]
pub struct IssueTokenRequest {
    /// Requested token lifetime in hours; `issue_token` substitutes 24 when absent.
    pub ttl_hours: Option<i64>,
    /// Label passed to the token manager and echoed back; empty when absent.
    pub label: Option<String>,
}
pub async fn proxy_handler(State(state): State<AppState>, req: Request) -> impl IntoResponse {
let path = req.uri().path().to_string();
let method = req.method().clone();
state.logger.verbose(|| format!("Incoming {method} {path}"));
let upstream_path = resolve_upstream_path(&path);
state
.logger
.debug(|| format!("Resolved upstream path: {upstream_path}"));
let upstream_url = format!(
"{}{}",
state.upstream_base_url.trim_end_matches('/'),
upstream_path
);
let upstream_url = if let Some(query) = req.uri().query() {
format!("{upstream_url}?{query}")
} else {
upstream_url
};
if let Some(session_id) = req.headers().get("x-claude-code-session-id") {
state
.logger
.verbose(|| format!("Session: {}", session_id.to_str().unwrap_or("<invalid>")));
}
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let Some(token) = auth_header else {
state.logger.debug(|| "Missing Authorization header");
return error_response(
StatusCode::UNAUTHORIZED,
"authentication_error",
"Missing Authorization header with Bearer token",
);
};
let custom_token = token.to_string();
if let Err(e) = state.token_manager.validate_token(&custom_token) {
let status = match &e {
crate::token::TokenError::Revoked => StatusCode::FORBIDDEN,
_ => StatusCode::UNAUTHORIZED,
};
state
.logger
.debug(|| format!("Token validation failed: {e}"));
return error_response(status, "authentication_error", &format!("{e}"));
}
let oauth_token = match state.oauth_provider.get_token() {
Ok(token) => token,
Err(e) => {
tracing::error!("Failed to get OAuth token: {e}");
return error_response(
StatusCode::BAD_GATEWAY,
"api_error",
"Upstream authentication unavailable",
);
}
};
let upstream_headers = build_upstream_headers(req.headers(), &oauth_token, &state.logger);
let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
Ok(bytes) => bytes,
Err(e) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_request_error",
&format!("Failed to read request body: {e}"),
);
}
};
state.logger.verbose(|| {
format!(
"Forwarding {method} {upstream_url} ({} bytes)",
body_bytes.len()
)
});
let upstream_req = state
.client
.request(method, &upstream_url)
.headers(upstream_headers)
.body(body_bytes);
let upstream_resp = match upstream_req.send().await {
Ok(resp) => resp,
Err(e) => {
tracing::error!("Upstream request failed: {e}");
return error_response(
StatusCode::BAD_GATEWAY,
"api_error",
&format!("Upstream request failed: {e}"),
);
}
};
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
state
.logger
.verbose(|| format!("Upstream responded: {status}"));
let mut response_headers = HeaderMap::new();
for (name, value) in upstream_resp.headers() {
let name_lower = name.as_str().to_lowercase();
if HOP_BY_HOP_HEADERS.contains(&name_lower.as_str()) {
continue;
}
response_headers.insert(name.clone(), value.clone());
}
let stream = upstream_resp
.bytes_stream()
.map(|chunk| chunk.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
let body = Body::from_stream(stream);
let mut response = Response::new(body);
*response.status_mut() = status;
*response.headers_mut() = response_headers;
response
}
#[must_use]
pub fn resolve_upstream_path(path: &str) -> String {
if let Some(rest) = path.strip_prefix("/api/latest/anthropic") {
return rest.to_string();
}
path.to_string()
}
fn build_upstream_headers(incoming: &HeaderMap, oauth_token: &str, logger: &LogLazy) -> HeaderMap {
let mut headers = HeaderMap::new();
for (name, value) in incoming {
let name_lower = name.as_str().to_lowercase();
if name_lower == "authorization" || HOP_BY_HOP_HEADERS.contains(&name_lower.as_str()) {
continue;
}
headers.insert(name.clone(), value.clone());
}
if let Ok(auth_val) = HeaderValue::from_str(&format!("Bearer {oauth_token}")) {
headers.insert("authorization", auth_val);
}
for &header_name in REQUIRED_FORWARD_HEADERS {
if let Some(val) = headers.get(header_name) {
logger.trace(|| {
format!(
"Forwarding {header_name}: {}",
val.to_str().unwrap_or("<non-utf8>")
)
});
}
}
headers
}
/// Builds a JSON error response carrying `status` and this body:
/// `{"type": "error", "error": {"type": <error_type>, "message": <message>}}`.
fn error_response(status: StatusCode, error_type: &str, message: &str) -> Response {
    let envelope = serde_json::json!({
        "type": "error",
        "error": {
            "type": error_type,
            "message": message
        }
    });
    (status, axum::Json(envelope)).into_response()
}