use std::sync::Arc;
use axum::{
Json,
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use fusillade::{PostgresRequestManager, ReqwestHttpClient};
use sqlx_pool_router::PoolProvider;
use super::jobs::CreateResponseInput;
use super::store::{self as response_store, ONWARDS_RESPONSE_ID_HEADER, OnwardsDaemonId};
/// Shared state handed to [`responses_middleware`] through axum's `State` extractor.
///
/// Generic over the pool provider `P` used by the fusillade request manager;
/// `P` defaults to `sqlx_pool_router::DbPools`.
#[derive(Clone)]
pub struct ResponsesMiddlewareState<P: PoolProvider + Clone = sqlx_pool_router::DbPools> {
/// Fusillade request manager. The middleware uses it to create
/// single-request batches and, for blocking flex requests, to poll for the
/// finished response.
pub request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>,
/// NOTE(review): not read anywhere in this file — presumably identifies this
/// daemon instance to other consumers of the state; confirm against `store`.
pub daemon_id: OnwardsDaemonId,
/// Copied into the `base_url` field of every batch/job input the middleware builds.
pub loopback_base_url: String,
/// Pool used for API-key model-access validation and for the `created_by` lookup.
pub dwctl_pool: sqlx::PgPool,
/// Job that is enqueued for realtime requests that are not `background`.
pub create_response_job: Arc<underway::Job<CreateResponseInput, crate::tasks::TaskState<P>>>,
}
#[tracing::instrument(skip_all)]
pub async fn responses_middleware<P: PoolProvider + Clone + Send + Sync + 'static>(
State(state): State<ResponsesMiddlewareState<P>>,
req: Request<Body>,
next: Next,
) -> Response {
if !should_intercept(req.method(), req.uri().path()) {
return next.run(req).await;
}
if req.headers().get("x-fusillade-request-id").is_some() {
return next.run(req).await;
}
let (parts, body) = req.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(e) => {
tracing::error!(error = %e, "Failed to read request body in responses middleware");
return Response::builder().status(StatusCode::BAD_REQUEST).body(Body::empty()).unwrap();
}
};
let request_value: serde_json::Value = match serde_json::from_slice(&body_bytes) {
Ok(v) => v,
Err(e) => {
tracing::error!(error = %e, "Failed to parse request body in responses middleware");
return Response::builder().status(StatusCode::BAD_REQUEST).body(Body::empty()).unwrap();
}
};
let model = request_value["model"].as_str().unwrap_or("unknown");
let nested_path = parts.uri.path();
let is_responses_api = nested_path.ends_with("/responses");
let endpoint = format!("/v1{nested_path}");
let api_key = parts
.headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|s| s.to_string());
let (service_tier, background) = if is_responses_api {
let tier = resolve_service_tier(request_value["service_tier"].as_str());
let bg = request_value["background"].as_bool().unwrap_or(false);
(tier, bg)
} else {
(ServiceTier::Realtime, false)
};
tracing::debug!(
model = %model,
service_tier = %service_tier,
background = background,
endpoint = %endpoint,
"Routing inference request"
);
let request_id = uuid::Uuid::new_v4();
let batch_id = uuid::Uuid::new_v4();
let resp_id = format!("resp_{request_id}");
let completion_window = match service_tier {
ServiceTier::Flex => "1h",
ServiceTier::Realtime => "0s",
};
if matches!(service_tier, ServiceTier::Flex) {
match api_key.as_deref() {
None => {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("content-type", "application/json")
.body(Body::from(
serde_json::json!({"error": {"message": "API key required", "type": "invalid_request_error"}}).to_string(),
))
.unwrap();
}
Some(key) => {
if let Err(msg) = crate::error_enrichment::validate_api_key_model_access(state.dwctl_pool.clone(), key, model).await {
return Response::builder()
.status(StatusCode::FORBIDDEN)
.header("content-type", "application/json")
.body(Body::from(
serde_json::json!({"error": {"message": msg, "type": "invalid_request_error"}}).to_string(),
))
.unwrap();
}
}
}
}
let needs_sync_attribution = background || matches!(service_tier, ServiceTier::Flex);
let created_by = if needs_sync_attribution {
response_store::lookup_created_by(&state.dwctl_pool, api_key.as_deref()).await
} else {
None
};
let initial_state = match service_tier {
ServiceTier::Realtime => "processing", ServiceTier::Flex => "pending", };
let batch_input = fusillade::CreateSingleRequestBatchInput {
batch_id: Some(batch_id),
request_id,
body: request_value.to_string(),
model: model.to_string(),
base_url: state.loopback_base_url.clone(),
endpoint: endpoint.clone(),
completion_window: completion_window.to_string(),
initial_state: initial_state.to_string(),
api_key: api_key.clone(),
created_by,
};
match service_tier {
ServiceTier::Realtime => handle_realtime(&state, batch_input, batch_id, &resp_id, model, background, parts, body_bytes, next).await,
ServiceTier::Flex => handle_flex(&state, batch_input, &resp_id, model, background).await,
}
}
/// Scheduling tier selected for an intercepted inference request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServiceTier {
    /// Forwarded down the middleware stack immediately; tracked with a `"0s"`
    /// completion window and an initial state of `"processing"`.
    Realtime,
    /// Queued through fusillade with a `"1h"` completion window and an
    /// initial state of `"pending"`.
    Flex,
}
/// Renders the tier as its lowercase name (`realtime` / `flex`).
impl std::fmt::Display for ServiceTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ServiceTier::Flex => "flex",
            ServiceTier::Realtime => "realtime",
        };
        f.write_str(label)
    }
}
/// Maps the request's optional `service_tier` string to a [`ServiceTier`].
///
/// Only the exact value `"flex"` selects [`ServiceTier::Flex`]; a missing
/// value or any other string resolves to [`ServiceTier::Realtime`].
fn resolve_service_tier(tier: Option<&str>) -> ServiceTier {
    if matches!(tier, Some("flex")) {
        ServiceTier::Flex
    } else {
        ServiceTier::Realtime
    }
}
/// Serves a realtime-tier request by forwarding it to the rest of the stack.
///
/// Tracking is best-effort: a background request gets its batch created
/// directly in fusillade, a foreground request gets a create-response job
/// enqueued, and a failure of either is only logged as a warning.
///
/// The forwarded request carries the original body plus the
/// `x-fusillade-request-id`, `x-fusillade-batch-id` and
/// `ONWARDS_RESPONSE_ID_HEADER` headers, and `x-onwards-endpoint` /
/// `x-onwards-model` when those values parse as header values.
///
/// For `background` requests the downstream call runs in a detached task whose
/// response body is drained and discarded, and the caller immediately receives
/// `202 Accepted` with an `in_progress` response object. Otherwise the
/// downstream response is returned as-is.
// The argument list mirrors everything the middleware has already computed.
#[allow(clippy::too_many_arguments)]
async fn handle_realtime<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    batch_input: fusillade::CreateSingleRequestBatchInput,
    batch_id: uuid::Uuid,
    resp_id: &str,
    model: &str,
    background: bool,
    parts: axum::http::request::Parts,
    body_bytes: bytes::Bytes,
    next: Next,
) -> Response {
    // Saved before `batch_input` is consumed by either tracking path.
    let endpoint_header = batch_input.endpoint.clone();

    if background {
        let created =
            fusillade::Storage::create_single_request_batch(&*state.request_manager, batch_input).await;
        if let Err(e) = created {
            tracing::warn!(error = %e, "Failed to create realtime tracking batch");
        }
    } else {
        let fusillade::CreateSingleRequestBatchInput {
            request_id,
            body,
            model: job_model,
            base_url,
            endpoint,
            api_key,
            ..
        } = batch_input;
        let job_input = CreateResponseInput {
            batch_id,
            request_id,
            body,
            model: job_model,
            base_url,
            endpoint,
            api_key,
        };
        tracing::debug!(
            request_id = %job_input.request_id,
            model = %job_input.model,
            endpoint = %job_input.endpoint,
            "responses_middleware enqueueing create-response job"
        );
        if let Err(e) = state.create_response_job.enqueue(&job_input).await {
            tracing::warn!(error = %e, "Failed to enqueue create-response job");
        }
    }

    // Rebuild the request from the buffered body and attach the tracking headers.
    let mut req = Request::from_parts(parts, Body::from(body_bytes));
    {
        let headers = req.headers_mut();
        // The fusillade id is the response id without its `resp_` prefix.
        let raw_id = resp_id.strip_prefix("resp_").unwrap_or(resp_id);
        headers.insert(
            "x-fusillade-request-id",
            raw_id.parse().expect("response_id is valid header value"),
        );
        headers.insert(
            "x-fusillade-batch-id",
            batch_id.to_string().parse().expect("batch_id is valid header value"),
        );
        headers.insert(
            ONWARDS_RESPONSE_ID_HEADER,
            resp_id.parse().expect("response_id is valid header value"),
        );
        // Endpoint and model come from the request, so they are attached only
        // when they form valid header values.
        if let Ok(value) = endpoint_header.parse() {
            headers.insert("x-onwards-endpoint", value);
        }
        if let Ok(value) = model.parse() {
            headers.insert("x-onwards-model", value);
        }
    }

    if !background {
        return next.run(req).await;
    }

    let accepted = serde_json::json!({
        "id": resp_id,
        "object": "response",
        "status": "in_progress",
        "model": model,
        "background": true,
        "output": [],
    });
    // Run the downstream call detached and drain its body; the output is discarded.
    tokio::spawn(async move {
        let (_parts, body) = next.run(req).await.into_parts();
        let _ = axum::body::to_bytes(body, usize::MAX).await;
    });
    (StatusCode::ACCEPTED, Json(accepted)).into_response()
}
async fn handle_flex<P: PoolProvider + Clone + Send + Sync + 'static>(
state: &ResponsesMiddlewareState<P>,
batch_input: fusillade::CreateSingleRequestBatchInput,
resp_id: &str,
model: &str,
background: bool,
) -> Response {
if let Err(e) = fusillade::Storage::create_single_request_batch(&*state.request_manager, batch_input).await {
tracing::error!(error = %e, "Failed to create flex batch in fusillade");
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(
serde_json::json!({
"error": {
"message": "Failed to enqueue request",
"type": "server_error",
}
})
.to_string(),
))
.unwrap();
}
if background {
let response_body = serde_json::json!({
"id": resp_id,
"object": "response",
"status": "queued",
"model": model,
"background": true,
"service_tier": "flex",
"output": [],
});
tracing::debug!(response_id = %resp_id, "Enqueued flex request");
(StatusCode::ACCEPTED, Json(response_body)).into_response()
} else {
tracing::debug!(response_id = %resp_id, "Blocking flex — polling until daemon completes");
let poll_interval = std::time::Duration::from_millis(500);
let timeout = std::time::Duration::from_secs(3600);
match response_store::poll_until_complete(&state.request_manager, resp_id, poll_interval, timeout).await {
Ok(response_obj) => {
let status_code = if response_obj["status"].as_str() == Some("completed") {
StatusCode::OK
} else {
response_obj["error"]["code"]
.as_u64()
.and_then(|c| StatusCode::from_u16(c as u16).ok())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
};
(status_code, Json(response_obj)).into_response()
}
Err(e) => {
tracing::error!(error = %e, response_id = %resp_id, "Blocking flex poll failed");
let response_body = serde_json::json!({
"error": {
"message": format!("Request timed out: {e}"),
"type": "server_error",
}
});
(StatusCode::GATEWAY_TIMEOUT, Json(response_body)).into_response()
}
}
}
}
/// Returns `true` for requests the responses middleware should handle:
/// `POST` requests whose path ends with `/responses`, `/chat/completions`
/// or `/embeddings`.
pub(crate) fn should_intercept(method: &axum::http::Method, path: &str) -> bool {
    const INTERCEPTED_SUFFIXES: [&str; 3] = ["/responses", "/chat/completions", "/embeddings"];

    method == axum::http::Method::POST
        && INTERCEPTED_SUFFIXES.iter().any(|&suffix| path.ends_with(suffix))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Method;

    /// `true` when the given `service_tier` value resolves to the realtime tier.
    fn resolves_to_realtime(tier: Option<&str>) -> bool {
        matches!(resolve_service_tier(tier), ServiceTier::Realtime)
    }

    /// `true` when a `POST` to `path` would be intercepted.
    fn intercepts_post(path: &str) -> bool {
        should_intercept(&Method::POST, path)
    }

    // --- should_intercept: paths that are handled ---

    #[test]
    fn test_should_intercept_responses() {
        for path in ["/v1/responses", "/responses"] {
            assert!(intercepts_post(path), "expected {} to be intercepted", path);
        }
    }

    #[test]
    fn test_should_intercept_chat_completions() {
        assert!(intercepts_post("/v1/chat/completions"));
    }

    #[test]
    fn test_should_intercept_embeddings() {
        assert!(intercepts_post("/v1/embeddings"));
    }

    // --- should_intercept: requests that pass through ---

    #[test]
    fn test_should_not_intercept_get() {
        assert!(!should_intercept(&Method::GET, "/v1/responses"));
    }

    #[test]
    fn test_should_not_intercept_models() {
        for method in [Method::GET, Method::POST] {
            assert!(!should_intercept(&method, "/v1/models"));
        }
    }

    #[test]
    fn test_should_not_intercept_batches() {
        assert!(!intercepts_post("/v1/batches"));
    }

    #[test]
    fn test_should_not_intercept_files() {
        assert!(!intercepts_post("/v1/files"));
    }

    // --- resolve_service_tier ---

    #[test]
    fn test_resolve_service_tier_priority_is_realtime() {
        assert!(resolves_to_realtime(Some("priority")));
    }

    #[test]
    fn test_resolve_service_tier_default_is_realtime() {
        assert!(resolves_to_realtime(Some("default")));
    }

    #[test]
    fn test_resolve_service_tier_auto_is_realtime() {
        assert!(resolves_to_realtime(Some("auto")));
    }

    #[test]
    fn test_resolve_service_tier_none_is_realtime() {
        assert!(resolves_to_realtime(None));
    }

    #[test]
    fn test_resolve_service_tier_flex() {
        assert!(matches!(resolve_service_tier(Some("flex")), ServiceTier::Flex));
    }
}