mod client;
mod handler;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::extract::Json;
use axum::http::{header, HeaderMap, Method, Request};
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::Router;
use tower::{Layer, Service};
pub use client::ForwardingError;
pub use handler::stream_event_to_sse;
use crate::config::TranslationConfig;
/// Configuration for the Anthropic-compatible `/v1/messages` endpoint.
///
/// `Debug` is implemented by hand so that the two credentials
/// (`api_key`, `client_api_key`) are never printed; only whether they are
/// empty is revealed.
#[derive(Clone)]
pub struct AnthropicCompatConfig {
    /// Base URL handed to the forwarding client.
    pub backend_url: String,
    /// Credential handed to the forwarding client together with `backend_url`.
    pub api_key: String,
    /// Key that incoming clients must present (via `x-api-key` or
    /// `Authorization: Bearer`). An empty value rejects every request.
    pub client_api_key: String,
    /// Translation settings.
    pub translation: TranslationConfig,
}

impl std::fmt::Debug for AnthropicCompatConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets must not end up in logs; show only whether one is configured.
        fn redact(secret: &str) -> &'static str {
            if secret.is_empty() {
                "<empty>"
            } else {
                "<redacted>"
            }
        }
        f.debug_struct("AnthropicCompatConfig")
            .field("backend_url", &self.backend_url)
            .field("api_key", &redact(&self.api_key))
            .field("client_api_key", &redact(&self.client_api_key))
            .field("translation", &self.translation)
            .finish()
    }
}
impl AnthropicCompatConfig {
    /// Starts a new [`AnthropicCompatConfigBuilder`].
    ///
    /// All string fields begin empty and `translation` begins at
    /// `TranslationConfig::default()`.
    pub fn builder() -> AnthropicCompatConfigBuilder {
        let empty = String::default;
        AnthropicCompatConfigBuilder {
            backend_url: empty(),
            api_key: empty(),
            client_api_key: empty(),
            translation: TranslationConfig::default(),
        }
    }
}
/// Incremental builder for [`AnthropicCompatConfig`].
///
/// Obtain one with [`AnthropicCompatConfig::builder`]; any field that is never
/// set keeps the initial value chosen there.
pub struct AnthropicCompatConfigBuilder {
    backend_url: String,
    api_key: String,
    client_api_key: String,
    translation: TranslationConfig,
}

impl AnthropicCompatConfigBuilder {
    /// Sets the backend base URL.
    pub fn backend_url(self, url: impl Into<String>) -> Self {
        Self {
            backend_url: url.into(),
            ..self
        }
    }

    /// Sets the credential used when talking to the backend.
    pub fn api_key(self, key: impl Into<String>) -> Self {
        Self {
            api_key: key.into(),
            ..self
        }
    }

    /// Sets the key that incoming clients must present.
    pub fn client_api_key(self, key: impl Into<String>) -> Self {
        Self {
            client_api_key: key.into(),
            ..self
        }
    }

    /// Replaces the translation settings.
    pub fn translation(self, config: TranslationConfig) -> Self {
        Self {
            translation: config,
            ..self
        }
    }

    /// Consumes the builder and returns the assembled config.
    ///
    /// No validation is performed; fields are moved over as-is.
    pub fn build(self) -> AnthropicCompatConfig {
        let Self {
            backend_url,
            api_key,
            client_api_key,
            translation,
        } = self;
        AnthropicCompatConfig {
            backend_url,
            api_key,
            client_api_key,
            translation,
        }
    }
}
/// State shared (behind an `Arc`) by the router and the tower layer.
pub(crate) struct MiddlewareState {
    /// The configuration this state was built from.
    pub(crate) config: AnthropicCompatConfig,
    /// Client built from `config.backend_url` and `config.api_key`.
    pub(crate) client: client::ForwardingClient,
}

/// Builds the shared state: constructs the forwarding client from the
/// config's backend URL and API key, then stores both in an `Arc`.
fn make_state(config: AnthropicCompatConfig) -> Arc<MiddlewareState> {
    let forwarding_client = client::ForwardingClient::new(&config.backend_url, &config.api_key);
    let state = MiddlewareState {
        client: forwarding_client,
        config,
    };
    Arc::new(state)
}
/// Returns `true` when the request carries the expected client API key.
///
/// The key is accepted from either the `x-api-key` header or an
/// `Authorization: <scheme> <key>` header whose scheme is `Bearer`
/// (matched case-insensitively, as RFC 7235 requires for auth-schemes).
/// An empty `expected` key rejects every request, so a missing
/// configuration fails closed rather than open.
///
/// Key comparison does not short-circuit on the first differing byte, so
/// response timing does not reveal how much of a guessed key was correct.
fn request_has_valid_client_api_key(headers: &HeaderMap, expected: &str) -> bool {
    if expected.is_empty() {
        return false;
    }
    let expected = expected.as_bytes();

    let x_api_key_matches = headers
        .get("x-api-key")
        .and_then(|v| v.to_str().ok())
        .is_some_and(|key| constant_time_eq(key.as_bytes(), expected));
    if x_api_key_matches {
        return true;
    }

    headers
        .get(header::AUTHORIZATION)
        .and_then(|v| v.to_str().ok())
        .and_then(bearer_token)
        .is_some_and(|token| constant_time_eq(token.as_bytes(), expected))
}

/// Extracts the credentials from an `Authorization` value that uses the
/// `Bearer` scheme (case-insensitive), skipping the separating spaces.
fn bearer_token(value: &str) -> Option<&str> {
    let (scheme, token) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("bearer")
        .then_some(token.trim_start())
}

/// Compares two byte strings without an early exit on the first mismatch.
///
/// Best-effort only: lengths are still compared up front (key length is not
/// treated as secret), and the optimizer is not formally prevented from
/// changing the loop.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Builds the `401 Unauthorized` response sent when the client API key is
/// missing or does not match; the JSON body is produced by
/// `create_anthropic_error` with `ErrorType::AuthenticationError`.
fn unauthorized_response() -> Response {
    let message = "Invalid or missing API key".to_string();
    let body = crate::mapping::errors_map::create_anthropic_error(
        crate::anthropic::ErrorType::AuthenticationError,
        message,
        None,
    );
    let status = axum::http::StatusCode::UNAUTHORIZED;
    (status, Json(body)).into_response()
}
/// Largest request body (in bytes) accepted on `/v1/messages`: 32 MiB.
const MAX_MESSAGES_BODY_BYTES: usize = 32 * 1024 * 1024;

/// Builds an error response with the given status whose JSON body is produced
/// by `create_anthropic_error` with `ErrorType::InvalidRequestError`.
fn invalid_request_response(status: axum::http::StatusCode, message: String) -> Response {
    let err = crate::mapping::errors_map::create_anthropic_error(
        crate::anthropic::ErrorType::InvalidRequestError,
        message,
        None,
    );
    (status, Json(err)).into_response()
}

/// The single `/v1/messages` pipeline shared by [`anthropic_compat_router`]
/// and [`AnthropicTranslationService`], so both entry points behave identically.
///
/// Steps:
/// 1. check the client API key (401 on failure);
/// 2. read the body, bounded by [`MAX_MESSAGES_BODY_BYTES`] (413 on failure);
/// 3. parse it as a `MessageCreateRequest` (400 on failure);
/// 4. delegate to `handler::handle_messages`.
///
/// Authentication runs before the body is read, so unauthenticated clients
/// can neither make the server buffer a large body nor probe the JSON parser.
async fn serve_messages(state: Arc<MiddlewareState>, req: Request<Body>) -> Response {
    let (parts, body) = req.into_parts();
    if !request_has_valid_client_api_key(&parts.headers, &state.config.client_api_key) {
        return unauthorized_response();
    }

    // NOTE(review): every read failure (not just exceeding the limit) is
    // reported as 413, matching the previous behavior of the layer.
    let Ok(body_bytes) = axum::body::to_bytes(body, MAX_MESSAGES_BODY_BYTES).await else {
        return invalid_request_response(
            axum::http::StatusCode::PAYLOAD_TOO_LARGE,
            "Request body too large".to_string(),
        );
    };

    match serde_json::from_slice::<crate::anthropic::MessageCreateRequest>(&body_bytes) {
        Ok(anthropic_req) => handler::handle_messages(state, anthropic_req).await,
        Err(e) => invalid_request_response(
            axum::http::StatusCode::BAD_REQUEST,
            format!("Invalid JSON: {e}"),
        ),
    }
}

/// Builds a standalone [`Router`] serving `POST /v1/messages`.
///
/// The route takes the raw request rather than a `Json` extractor. This lets
/// the API key be checked before any parsing, and lets errors, body limit and
/// content handling match [`AnthropicTranslationLayer`] exactly.
pub fn anthropic_compat_router(config: AnthropicCompatConfig) -> Router {
    let state = make_state(config);
    Router::new().route(
        "/v1/messages",
        post(move |req: Request<Body>| {
            let state = Arc::clone(&state);
            async move { serve_messages(state, req).await }
        }),
    )
}

/// Tower [`Layer`] that serves `POST /v1/messages` itself and forwards every
/// other request to the wrapped service unchanged.
#[derive(Clone)]
pub struct AnthropicTranslationLayer {
    state: Arc<MiddlewareState>,
}

impl AnthropicTranslationLayer {
    /// Creates the layer, building its shared state from `config`.
    pub fn new(config: AnthropicCompatConfig) -> Self {
        Self {
            state: make_state(config),
        }
    }
}

impl<S> Layer<S> for AnthropicTranslationLayer {
    type Service = AnthropicTranslationService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        AnthropicTranslationService {
            inner,
            state: Arc::clone(&self.state),
        }
    }
}

/// Service produced by [`AnthropicTranslationLayer`]; see the layer for behavior.
#[derive(Clone)]
pub struct AnthropicTranslationService<S> {
    inner: S,
    state: Arc<MiddlewareState>,
}

impl<S> Service<Request<Body>> for AnthropicTranslationService<S>
where
    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
{
    type Response = Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the inner service, even though intercepted
    /// `/v1/messages` requests never reach it.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let intercept = req.method() == Method::POST && req.uri().path() == "/v1/messages";
        if intercept {
            let state = Arc::clone(&self.state);
            // Errors are always rendered as responses, so this branch never yields `Err`.
            Box::pin(async move { Ok::<_, S::Error>(serve_messages(state, req).await) })
        } else {
            Box::pin(self.inner.call(req))
        }
    }
}