1use std::sync::Arc;
15
16use serde_json::json;
17use tokio::sync::Mutex;
18
19use super::response;
20use crate::clients::base::LLMClient;
21use crate::context::manager::ContextManager;
22
23mod request_handlers;
24#[cfg(test)]
25const MAX_BODY_SIZE: usize = request_handlers::MAX_BODY_SIZE;
26#[cfg(test)]
27mod test_helpers;
28#[cfg(test)]
29pub use test_helpers::{format_anthropic_sse_body, format_sse_body, parse_http_request};
30
31pub struct HTTPServer {
33 pub host: String,
35 pub port: u16,
37 pub serialize_requests: bool,
39 pub max_retries: i32,
41 pub rescue_enabled: bool,
43 pub model_name: String,
45 request_mutex: Arc<Mutex<()>>,
47}
48
49impl HTTPServer {
50 pub fn new(
52 host: &str,
53 port: u16,
54 serialize_requests: bool,
55 max_retries: i32,
56 rescue_enabled: bool,
57 model_name: &str,
58 ) -> Self {
59 Self {
60 host: host.to_string(),
61 port,
62 serialize_requests,
63 max_retries,
64 rescue_enabled,
65 model_name: model_name.to_string(),
66 request_mutex: Arc::new(Mutex::new(())),
67 }
68 }
69
70 pub fn serve_blocking<C>(
80 self: Arc<Self>,
81 client: Arc<C>,
82 ctx: Arc<Mutex<ContextManager>>,
83 ) -> anyhow::Result<()>
84 where
85 C: LLMClient + 'static,
86 {
87 let rt = tokio::runtime::Builder::new_current_thread()
88 .enable_all()
89 .build()?;
90 let local = tokio::task::LocalSet::new();
91 local.block_on(&rt, async move {
92 let addr: std::net::SocketAddr = format!("{}:{}", self.host, self.port)
93 .parse()
94 .map_err(|e| anyhow::anyhow!("Invalid bind address: {}", e))?;
95 let app = Self::build_router(self.clone(), client, ctx);
96 let listener = tokio::net::TcpListener::bind(addr).await?;
97 axum::serve(listener, app).await?;
98 Ok::<(), anyhow::Error>(())
99 })
100 }
101
102 pub fn serve_blocking_with_shutdown<C>(
106 self: Arc<Self>,
107 client: Arc<C>,
108 ctx: Arc<Mutex<ContextManager>>,
109 shutdown_rx: tokio::sync::oneshot::Receiver<()>,
110 ) -> anyhow::Result<()>
111 where
112 C: LLMClient + 'static,
113 {
114 let rt = tokio::runtime::Builder::new_current_thread()
115 .enable_all()
116 .build()?;
117 let local = tokio::task::LocalSet::new();
118 local.block_on(&rt, async move {
119 let addr: std::net::SocketAddr = format!("{}:{}", self.host, self.port)
120 .parse()
121 .map_err(|e| anyhow::anyhow!("Invalid bind address: {}", e))?;
122 let app = Self::build_router(self.clone(), client, ctx);
123 let listener = tokio::net::TcpListener::bind(addr).await?;
124 axum::serve(listener, app)
125 .with_graceful_shutdown(async move {
126 let _ = shutdown_rx.await;
127 })
128 .await?;
129 Ok::<(), anyhow::Error>(())
130 })
131 }
132}
133
134#[derive(Clone)]
136pub struct RouterState<C> {
137 pub server: Arc<HTTPServer>,
139 pub client: Arc<C>,
141 pub ctx: Arc<Mutex<ContextManager>>,
143}
144
145pub async fn health() -> axum::response::Response {
147 response::build_response(200, "application/json", json!({"status": "ok"}).to_string())
148}
149
150pub async fn models<C>(
152 axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
153) -> axum::response::Response
154where
155 C: LLMClient + 'static,
156{
157 response::build_response(
158 200,
159 "application/json",
160 json!({
161 "object": "list",
162 "data": [{
163 "id": state.server.model_name,
164 "object": "model",
165 "created": 0,
166 "owned_by": "local"
167 }]
168 })
169 .to_string(),
170 )
171}
172
173pub async fn chat<C>(
175 axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
176 body: axum::body::Bytes,
177) -> axum::response::Response
178where
179 C: LLMClient + 'static,
180{
181 state
182 .server
183 .handle_chat_completions_response(&body, &state.client, &state.ctx)
184 .await
185}
186
187pub async fn messages<C>(
189 axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
190 body: axum::body::Bytes,
191) -> axum::response::Response
192where
193 C: LLMClient + 'static,
194{
195 state
196 .server
197 .handle_anthropic_messages_response(&body, &state.client, &state.ctx)
198 .await
199}
200
201pub async fn opts_chat() -> axum::response::Response {
203 response::cors_preflight_response()
204}
205
206pub async fn opts_messages() -> axum::response::Response {
208 response::cors_preflight_response()
209}
210
211impl HTTPServer {
212 fn build_router<C>(
213 server: Arc<Self>,
214 client: Arc<C>,
215 ctx: Arc<Mutex<ContextManager>>,
216 ) -> axum::Router
217 where
218 C: LLMClient + 'static,
219 {
220 use axum::{
221 routing::{get, options, post},
222 Router,
223 };
224
225 let state = Arc::new(RouterState {
226 server,
227 client,
228 ctx,
229 });
230
231 Router::new()
232 .route("/health", get(health))
233 .route("/v1/models", get(models::<C>))
234 .route("/v1/chat/completions", post(chat::<C>))
235 .route("/v1/chat/completions", options(opts_chat))
236 .route("/v1/messages", post(messages::<C>))
237 .route("/v1/messages", options(opts_messages))
238 .with_state(state)
239 }
240
241 pub fn cors_headers() -> Vec<(&'static str, &'static str)> {
243 response::cors_headers()
244 }
245}
246
247#[cfg(test)]
248mod tests;