Skip to main content

forge_guardrails/proxy/
server.rs

1//! Lightweight HTTP server with OpenAI-compatible endpoints.
2//!
3//! # Endpoints
4//! - `GET /health`
5//! - `GET /v1/models`
6//! - `POST /v1/chat/completions`
7//! - `OPTIONS /v1/chat/completions`
8//! - `POST /v1/messages`
9//! - `OPTIONS /v1/messages`
10//!
11//! Supports optional request serialization via a single-worker queue for
12//! single-GPU environments.
13
14use std::sync::Arc;
15
16use serde_json::json;
17use tokio::sync::Mutex;
18
19use super::response;
20use crate::clients::base::LLMClient;
21use crate::context::manager::ContextManager;
22
23mod request_handlers;
24#[cfg(test)]
25const MAX_BODY_SIZE: usize = request_handlers::MAX_BODY_SIZE;
26#[cfg(test)]
27mod test_helpers;
28#[cfg(test)]
29pub use test_helpers::{format_anthropic_sse_body, format_sse_body, parse_http_request};
30
31/// HTTP server configuration for the OpenAI-compatible proxy.
32pub struct HTTPServer {
33    /// Host to bind.
34    pub host: String,
35    /// Port to bind.
36    pub port: u16,
37    /// Whether to serialize requests (single-worker queue).
38    pub serialize_requests: bool,
39    /// Maximum retries for tool-call validation.
40    pub max_retries: i32,
41    /// Whether rescue parsing is enabled.
42    pub rescue_enabled: bool,
43    /// The model name reported in responses.
44    pub model_name: String,
45    /// Mutex for request serialization.
46    request_mutex: Arc<Mutex<()>>,
47}
48
49impl HTTPServer {
50    /// Creates a new `HTTPServer` instance with the specified binding and validation options.
51    pub fn new(
52        host: &str,
53        port: u16,
54        serialize_requests: bool,
55        max_retries: i32,
56        rescue_enabled: bool,
57        model_name: &str,
58    ) -> Self {
59        Self {
60            host: host.to_string(),
61            port,
62            serialize_requests,
63            max_retries,
64            rescue_enabled,
65            model_name: model_name.to_string(),
66            request_mutex: Arc::new(Mutex::new(())),
67        }
68    }
69
70    /// Serve the HTTP API using axum on a single-threaded local executor.
71    ///
72    /// Binds to `self.host:self.port`. Because `LLMClient` async methods
73    /// produce non-`Send` futures (native AFIT), this server runs on a
74    /// `tokio::task::LocalSet` which removes the cross-thread `Send` requirement,
75    /// matching Python's single-threaded asyncio model.
76    ///
77    /// Requests are serialized through `self.request_mutex` when
78    /// `serialize_requests=true`, matching Python's `asyncio.Semaphore(1)`.
79    pub fn serve_blocking<C>(
80        self: Arc<Self>,
81        client: Arc<C>,
82        ctx: Arc<Mutex<ContextManager>>,
83    ) -> anyhow::Result<()>
84    where
85        C: LLMClient + 'static,
86    {
87        let rt = tokio::runtime::Builder::new_current_thread()
88            .enable_all()
89            .build()?;
90        let local = tokio::task::LocalSet::new();
91        local.block_on(&rt, async move {
92            let addr: std::net::SocketAddr = format!("{}:{}", self.host, self.port)
93                .parse()
94                .map_err(|e| anyhow::anyhow!("Invalid bind address: {}", e))?;
95            let app = Self::build_router(self.clone(), client, ctx);
96            let listener = tokio::net::TcpListener::bind(addr).await?;
97            axum::serve(listener, app).await?;
98            Ok::<(), anyhow::Error>(())
99        })
100    }
101
102    /// Serve with graceful shutdown triggered by `shutdown_rx` resolving.
103    ///
104    /// Same single-threaded LocalSet model as `serve_blocking`.
105    pub fn serve_blocking_with_shutdown<C>(
106        self: Arc<Self>,
107        client: Arc<C>,
108        ctx: Arc<Mutex<ContextManager>>,
109        shutdown_rx: tokio::sync::oneshot::Receiver<()>,
110    ) -> anyhow::Result<()>
111    where
112        C: LLMClient + 'static,
113    {
114        let rt = tokio::runtime::Builder::new_current_thread()
115            .enable_all()
116            .build()?;
117        let local = tokio::task::LocalSet::new();
118        local.block_on(&rt, async move {
119            let addr: std::net::SocketAddr = format!("{}:{}", self.host, self.port)
120                .parse()
121                .map_err(|e| anyhow::anyhow!("Invalid bind address: {}", e))?;
122            let app = Self::build_router(self.clone(), client, ctx);
123            let listener = tokio::net::TcpListener::bind(addr).await?;
124            axum::serve(listener, app)
125                .with_graceful_shutdown(async move {
126                    let _ = shutdown_rx.await;
127                })
128                .await?;
129            Ok::<(), anyhow::Error>(())
130        })
131    }
132}
133
134/// Shared state passed to axum route handlers.
135#[derive(Clone)]
136pub struct RouterState<C> {
137    /// Server configuration and handler methods.
138    pub server: Arc<HTTPServer>,
139    /// Backend LLM client.
140    pub client: Arc<C>,
141    /// Per-server context manager.
142    pub ctx: Arc<Mutex<ContextManager>>,
143}
144
145/// Return the proxy health response.
146pub async fn health() -> axum::response::Response {
147    response::build_response(200, "application/json", json!({"status": "ok"}).to_string())
148}
149
150/// Return the OpenAI-compatible model list.
151pub async fn models<C>(
152    axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
153) -> axum::response::Response
154where
155    C: LLMClient + 'static,
156{
157    response::build_response(
158        200,
159        "application/json",
160        json!({
161            "object": "list",
162            "data": [{
163                "id": state.server.model_name,
164                "object": "model",
165                "created": 0,
166                "owned_by": "local"
167            }]
168        })
169        .to_string(),
170    )
171}
172
173/// Handle an OpenAI-compatible chat completion request.
174pub async fn chat<C>(
175    axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
176    body: axum::body::Bytes,
177) -> axum::response::Response
178where
179    C: LLMClient + 'static,
180{
181    state
182        .server
183        .handle_chat_completions_response(&body, &state.client, &state.ctx)
184        .await
185}
186
187/// Handle an Anthropic-compatible messages request.
188pub async fn messages<C>(
189    axum::extract::State(state): axum::extract::State<Arc<RouterState<C>>>,
190    body: axum::body::Bytes,
191) -> axum::response::Response
192where
193    C: LLMClient + 'static,
194{
195    state
196        .server
197        .handle_anthropic_messages_response(&body, &state.client, &state.ctx)
198        .await
199}
200
201/// Return the chat-completions CORS preflight response.
202pub async fn opts_chat() -> axum::response::Response {
203    response::cors_preflight_response()
204}
205
206/// Return the messages CORS preflight response.
207pub async fn opts_messages() -> axum::response::Response {
208    response::cors_preflight_response()
209}
210
211impl HTTPServer {
212    fn build_router<C>(
213        server: Arc<Self>,
214        client: Arc<C>,
215        ctx: Arc<Mutex<ContextManager>>,
216    ) -> axum::Router
217    where
218        C: LLMClient + 'static,
219    {
220        use axum::{
221            routing::{get, options, post},
222            Router,
223        };
224
225        let state = Arc::new(RouterState {
226            server,
227            client,
228            ctx,
229        });
230
231        Router::new()
232            .route("/health", get(health))
233            .route("/v1/models", get(models::<C>))
234            .route("/v1/chat/completions", post(chat::<C>))
235            .route("/v1/chat/completions", options(opts_chat))
236            .route("/v1/messages", post(messages::<C>))
237            .route("/v1/messages", options(opts_messages))
238            .with_state(state)
239    }
240
241    /// Build CORS headers for OPTIONS responses.
242    pub fn cors_headers() -> Vec<(&'static str, &'static str)> {
243        response::cors_headers()
244    }
245}
246
247#[cfg(test)]
248mod tests;