use std::sync::Arc;
use serde_json::json;
use tokio::sync::Mutex;
use super::response;
use crate::clients::base::LLMClient;
use crate::context::manager::ContextManager;
// Submodule of this server module; it defines the `MAX_BODY_SIZE` limit aliased below.
// NOTE(review): presumably it also hosts the `handle_*_response` methods called by the
// route handlers in this file — confirm.
mod request_handlers;
// Test-only alias so tests in this module can refer to the body-size limit directly.
#[cfg(test)]
const MAX_BODY_SIZE: usize = request_handlers::MAX_BODY_SIZE;
// Helper module that is compiled only for test builds.
#[cfg(test)]
mod test_helpers;
// In test builds, re-export the helper functions publicly from this module.
#[cfg(test)]
pub use test_helpers::{format_anthropic_sse_body, format_sse_body, parse_http_request};
/// Configuration and shared state for the HTTP front-end.
///
/// The server is wrapped in an `Arc` and handed to the axum router, so every
/// request handler sees the same instance.
pub struct HTTPServer {
/// Host part of the listen address; combined with `port` when binding.
pub host: String,
/// TCP port part of the listen address.
pub port: u16,
/// NOTE(review): presumably enables one-request-at-a-time processing via
/// `request_mutex`; the use site is not in this file — confirm.
pub serialize_requests: bool,
/// NOTE(review): retry budget consumed by the request handlers (not visible
/// here) — confirm exact semantics.
pub max_retries: i32,
/// NOTE(review): feature flag read by the request handlers (not visible
/// here) — confirm what "rescue" does.
pub rescue_enabled: bool,
/// Model identifier reported as `id` by the `/v1/models` endpoint.
pub model_name: String,
/// Async mutex created fresh in `new`; not used in this file.
/// NOTE(review): presumably guards request processing when
/// `serialize_requests` is set — confirm.
request_mutex: Arc<Mutex<()>>,
}
impl HTTPServer {
    /// Creates a server configuration from the given settings.
    ///
    /// `host` and `model_name` are copied into owned strings, and a fresh
    /// request mutex is allocated for this instance.
    pub fn new(
        host: &str,
        port: u16,
        serialize_requests: bool,
        max_retries: i32,
        rescue_enabled: bool,
        model_name: &str,
    ) -> Self {
        Self {
            host: host.to_string(),
            port,
            serialize_requests,
            max_retries,
            rescue_enabled,
            model_name: model_name.to_string(),
            request_mutex: Arc::new(Mutex::new(())),
        }
    }

    /// Turns `host` and `port` into the socket address to listen on.
    ///
    /// # Errors
    ///
    /// Returns an `Invalid bind address: …` error when `host` is neither an
    /// IP literal nor parseable as the host part of `host:port`.
    fn resolve_bind_addr(&self) -> anyhow::Result<std::net::SocketAddr> {
        // A bare IP literal is combined with the port directly. This matters for
        // IPv6: "::1" joined as "::1:8080" is not valid socket-address syntax
        // (brackets are required), so the string-based parse below would reject it.
        if let Ok(ip) = self.host.parse::<std::net::IpAddr>() {
            return Ok(std::net::SocketAddr::new(ip, self.port));
        }
        // Fallback keeps the previous behavior for everything else, e.g. an
        // already-bracketed IPv6 host such as "[::1]".
        format!("{}:{}", self.host, self.port)
            .parse::<std::net::SocketAddr>()
            .map_err(|e| anyhow::anyhow!("Invalid bind address: {}", e))
    }

    /// Runs the HTTP server on the calling thread and blocks until serving ends.
    ///
    /// A current-thread Tokio runtime is created for the server and the serve
    /// future is driven inside a `LocalSet`.
    ///
    /// # Errors
    ///
    /// Fails if the runtime cannot be built, the bind address is invalid, the
    /// listener cannot be bound, or serving returns an error.
    pub fn serve_blocking<C>(
        self: Arc<Self>,
        client: Arc<C>,
        ctx: Arc<Mutex<ContextManager>>,
    ) -> anyhow::Result<()>
    where
        C: LLMClient + 'static,
    {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()?;
        let local = tokio::task::LocalSet::new();
        local.block_on(&rt, async move {
            let addr = self.resolve_bind_addr()?;
            let app = Self::build_router(Arc::clone(&self), client, ctx);
            let listener = tokio::net::TcpListener::bind(addr).await?;
            axum::serve(listener, app).await?;
            Ok::<(), anyhow::Error>(())
        })
    }

    /// Like [`HTTPServer::serve_blocking`], but stops gracefully once
    /// `shutdown_rx` resolves.
    ///
    /// The result of awaiting `shutdown_rx` is ignored, so shutdown starts both
    /// when a value is sent and when the sender is dropped.
    ///
    /// # Errors
    ///
    /// Fails if the runtime cannot be built, the bind address is invalid, the
    /// listener cannot be bound, or serving returns an error.
    pub fn serve_blocking_with_shutdown<C>(
        self: Arc<Self>,
        client: Arc<C>,
        ctx: Arc<Mutex<ContextManager>>,
        shutdown_rx: tokio::sync::oneshot::Receiver<()>,
    ) -> anyhow::Result<()>
    where
        C: LLMClient + 'static,
    {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()?;
        let local = tokio::task::LocalSet::new();
        local.block_on(&rt, async move {
            let addr = self.resolve_bind_addr()?;
            let app = Self::build_router(Arc::clone(&self), client, ctx);
            let listener = tokio::net::TcpListener::bind(addr).await?;
            axum::serve(listener, app)
                .with_graceful_shutdown(async move {
                    // Ok (signal sent) and Err (sender dropped) both mean "stop".
                    let _ = shutdown_rx.await;
                })
                .await?;
            Ok::<(), anyhow::Error>(())
        })
    }
}
/// State shared by all route handlers.
///
/// `build_router` wraps one instance in an `Arc` and installs it as the router
/// state; handlers receive it through the `State` extractor.
#[derive(Clone)]
pub struct RouterState<C> {
/// Server configuration (e.g. the model name reported by `/v1/models`).
pub server: Arc<HTTPServer>,
/// LLM client passed to the chat and messages request handlers.
pub client: Arc<C>,
/// Context manager behind an async mutex, passed to the same handlers.
pub ctx: Arc<Mutex<ContextManager>>,
}
/// Handler for `GET /health`.
///
/// Builds the response through `response::build_response` with status code
/// 200, content type `application/json` and the body `{"status":"ok"}`.
pub async fn health() -> axum::response::Response {
    let payload = json!({ "status": "ok" });
    let body = payload.to_string();
    response::build_response(200, "application/json", body)
}
/// Handler for `GET /v1/models`.
///
/// Responds with a JSON list containing a single model entry whose `id` is the
/// server's configured `model_name`.
pub async fn models<C>(
    axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
) -> axum::response::Response
where
    C: LLMClient + 'static,
{
    // The one model this server exposes.
    let entry = json!({
        "id": state.server.model_name,
        "object": "model",
        "created": 0,
        "owned_by": "local"
    });
    let listing = json!({
        "object": "list",
        "data": [entry]
    });
    response::build_response(200, "application/json", listing.to_string())
}
/// Handler for `POST /v1/chat/completions`.
///
/// Forwards the raw request body, together with the shared client and context
/// manager, to `HTTPServer::handle_chat_completions_response` and returns its
/// response unchanged.
pub async fn chat<C>(
    axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
    body: axum::body::Bytes,
) -> axum::response::Response
where
    C: LLMClient + 'static,
{
    let server = &state.server;
    let client = &state.client;
    let ctx = &state.ctx;
    server
        .handle_chat_completions_response(&body, client, ctx)
        .await
}
/// Handler for `POST /v1/messages`.
///
/// Forwards the raw request body, together with the shared client and context
/// manager, to `HTTPServer::handle_anthropic_messages_response` and returns its
/// response unchanged.
pub async fn messages<C>(
    axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
    body: axum::body::Bytes,
) -> axum::response::Response
where
    C: LLMClient + 'static,
{
    let server = &state.server;
    let client = &state.client;
    let ctx = &state.ctx;
    server
        .handle_anthropic_messages_response(&body, client, ctx)
        .await
}
/// Handler for `OPTIONS /v1/chat/completions`.
///
/// Returns whatever `response::cors_preflight_response` produces.
pub async fn opts_chat() -> axum::response::Response {
response::cors_preflight_response()
}
/// Handler for `OPTIONS /v1/messages`.
///
/// Returns whatever `response::cors_preflight_response` produces.
pub async fn opts_messages() -> axum::response::Response {
response::cors_preflight_response()
}
impl HTTPServer {
    /// Assembles the axum router for this server.
    ///
    /// Routes:
    /// - `GET /health`
    /// - `GET /v1/models`
    /// - `POST` and `OPTIONS /v1/chat/completions`
    /// - `POST` and `OPTIONS /v1/messages`
    ///
    /// The server, client and context manager are bundled into one shared
    /// `RouterState` that is installed as the router state.
    fn build_router<C>(
        server: Arc<Self>,
        client: Arc<C>,
        ctx: Arc<Mutex<ContextManager>>,
    ) -> axum::Router
    where
        C: LLMClient + 'static,
    {
        use axum::routing::{get, post};

        let shared = Arc::new(RouterState {
            server,
            client,
            ctx,
        });

        axum::Router::new()
            .route("/health", get(health))
            .route("/v1/models", get(models::<C>))
            // One method router per path: POST does the work, OPTIONS answers preflight.
            .route(
                "/v1/chat/completions",
                post(chat::<C>).options(opts_chat),
            )
            .route(
                "/v1/messages",
                post(messages::<C>).options(opts_messages),
            )
            .with_state(shared)
    }

    /// Returns the CORS header name/value pairs from `response::cors_headers`.
    pub fn cors_headers() -> Vec<(&'static str, &'static str)> {
        response::cors_headers()
    }
}
#[cfg(test)]
mod tests;