use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use futures_util::StreamExt;
use reqwest::Client;
use crate::oauth::OAuthProvider;
use crate::token::TokenManager;
/// Shared state injected into every handler through axum's `State` extractor.
///
/// The extractor clones this value for each handler invocation (hence `Clone`), so every field
/// should be cheap to clone. NOTE(review): `reqwest::Client` is internally reference-counted;
/// confirm `TokenManager` and `OAuthProvider` are likewise `Arc`-backed.
#[derive(Clone)]
pub struct AppState {
/// HTTP client used to send proxied requests upstream.
pub client: Client,
/// Issues (`issue_token`) and validates (`proxy_handler`) the proxy's own bearer tokens.
pub token_manager: TokenManager,
/// Supplies the OAuth access token that `proxy_handler` sends upstream as `Authorization: Bearer …`.
pub oauth_provider: OAuthProvider,
/// Base URL that proxied paths are appended to; a trailing `/` is ignored.
pub upstream_base_url: String,
}
/// URL path prefix (including the trailing `/`) for the Anthropic API routes.
/// NOTE(review): presumably the prefix the proxy route is mounted under — confirm at the router.
pub const API_PREFIX: &str = "/api/latest/anthropic/";
/// Liveness probe: always replies `200 OK` with the text body `ok`.
// axum handlers must return a future, so this stays `async` even though it never awaits.
#[allow(clippy::unused_async)]
pub async fn health() -> impl IntoResponse {
    const BODY: &str = "ok";
    (StatusCode::OK, BODY)
}
#[allow(clippy::unused_async)]
pub async fn issue_token(
State(state): State<AppState>,
axum::Json(req): axum::Json<IssueTokenRequest>,
) -> impl IntoResponse {
let ttl = req.ttl_hours.unwrap_or(24);
let label = req.label.unwrap_or_default();
match state.token_manager.issue_token(ttl, &label) {
Ok(token) => (
StatusCode::OK,
axum::Json(serde_json::json!({
"token": token,
"ttl_hours": ttl,
"label": label,
})),
)
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(serde_json::json!({
"error": format!("Failed to issue token: {e}")
})),
)
.into_response(),
}
}
/// JSON request body accepted by `issue_token`.
#[derive(serde::Deserialize)]
pub struct IssueTokenRequest {
/// Token lifetime in hours; `issue_token` uses 24 when omitted.
pub ttl_hours: Option<i64>,
/// Free-form label passed to `TokenManager::issue_token` and echoed back in the response;
/// `issue_token` uses an empty string when omitted.
pub label: Option<String>,
}
pub async fn proxy_handler(State(state): State<AppState>, req: Request) -> impl IntoResponse {
let path = req.uri().path();
let downstream_path = path.strip_prefix("/api/latest/anthropic").unwrap_or(path);
let upstream_url = format!(
"{}{}",
state.upstream_base_url.trim_end_matches('/'),
downstream_path
);
let upstream_url = if let Some(query) = req.uri().query() {
format!("{upstream_url}?{query}")
} else {
upstream_url
};
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let custom_token = match auth_header {
Some(token) => token.to_string(),
None => {
return (
StatusCode::UNAUTHORIZED,
axum::Json(serde_json::json!({
"type": "error",
"error": {
"type": "authentication_error",
"message": "Missing Authorization header with Bearer token"
}
})),
)
.into_response();
}
};
if let Err(e) = state.token_manager.validate_token(&custom_token) {
let status = match &e {
crate::token::TokenError::Revoked => StatusCode::FORBIDDEN,
_ => StatusCode::UNAUTHORIZED,
};
return (
status,
axum::Json(serde_json::json!({
"type": "error",
"error": {
"type": "authentication_error",
"message": format!("{e}")
}
})),
)
.into_response();
}
let oauth_token = match state.oauth_provider.get_token() {
Ok(token) => token,
Err(e) => {
tracing::error!("Failed to get OAuth token: {e}");
return (
StatusCode::BAD_GATEWAY,
axum::Json(serde_json::json!({
"type": "error",
"error": {
"type": "api_error",
"message": "Upstream authentication unavailable"
}
})),
)
.into_response();
}
};
let method = req.method().clone();
let mut upstream_headers = HeaderMap::new();
for (name, value) in req.headers() {
let name_str = name.as_str().to_lowercase();
if matches!(
name_str.as_str(),
"host" | "authorization" | "connection" | "transfer-encoding"
) {
continue;
}
upstream_headers.insert(name.clone(), value.clone());
}
if let Ok(auth_val) = HeaderValue::from_str(&format!("Bearer {oauth_token}")) {
upstream_headers.insert("authorization", auth_val);
}
let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
Ok(bytes) => bytes,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
axum::Json(serde_json::json!({
"type": "error",
"error": {
"type": "invalid_request_error",
"message": format!("Failed to read request body: {e}")
}
})),
)
.into_response();
}
};
let upstream_req = state
.client
.request(method, &upstream_url)
.headers(upstream_headers)
.body(body_bytes);
let upstream_resp = match upstream_req.send().await {
Ok(resp) => resp,
Err(e) => {
tracing::error!("Upstream request failed: {e}");
return (
StatusCode::BAD_GATEWAY,
axum::Json(serde_json::json!({
"type": "error",
"error": {
"type": "api_error",
"message": format!("Upstream request failed: {e}")
}
})),
)
.into_response();
}
};
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let mut response_headers = HeaderMap::new();
for (name, value) in upstream_resp.headers() {
let name_str = name.as_str().to_lowercase();
if matches!(
name_str.as_str(),
"connection" | "transfer-encoding" | "keep-alive"
) {
continue;
}
response_headers.insert(name.clone(), value.clone());
}
let stream = upstream_resp
.bytes_stream()
.map(|chunk| chunk.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)));
let body = Body::from_stream(stream);
let mut response = Response::new(body);
*response.status_mut() = status;
*response.headers_mut() = response_headers;
response
}