use std::sync::Arc;
use axum::{
Json,
body::Body,
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use fusillade::{PostgresRequestManager, ReqwestHttpClient};
use sqlx_pool_router::PoolProvider;
use super::jobs::CreateResponseInput;
use super::store::{self as response_store, ONWARDS_RESPONSE_ID_HEADER, OnwardsDaemonId};
/// Shared state for [`responses_middleware`]: persistence handles, job
/// queues, and the clients used by the in-process ("warm path")
/// multi-step execution.
#[derive(Clone)]
pub struct ResponsesMiddlewareState<P: PoolProvider + Clone = sqlx_pool_router::DbPools> {
    /// Fusillade request manager used to create realtime/flex tracking rows.
    pub request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>,
    /// Identity of this daemon instance.
    pub daemon_id: OnwardsDaemonId,
    /// Base URL this service can reach itself on; used as the loopback
    /// endpoint/upstream for replayed requests.
    pub loopback_base_url: String,
    /// Pool used for API-key attribution, model-access checks, and tool
    /// resolution.
    pub dwctl_pool: sqlx::PgPool,
    /// Underway job that performs create-response bookkeeping for
    /// foreground realtime requests.
    pub create_response_job: Arc<underway::Job<CreateResponseInput, crate::tasks::TaskState<P>>>,
    /// Store holding pending warm-path inputs and completed responses.
    pub response_store: Arc<super::store::FusilladeResponseStore<P>>,
    /// Executor invoked for tool calls during warm-path multi-step loops.
    pub multi_step_tool_executor: Arc<crate::tool_executor::HttpToolExecutor>,
    /// HTTP client used by the warm-path loop to reach the upstream model.
    pub multi_step_http_client: Arc<dyn onwards::client::HttpClient + Send + Sync>,
    /// Loop limits/configuration for multi-step execution (Copy — passed
    /// by value into spawned tasks).
    pub loop_config: onwards::LoopConfig,
}
/// Axum middleware in front of the inference endpoints.
///
/// Responsibilities, in order:
/// 1. Pass through anything that is not an intercepted POST (see
///    [`should_intercept`]) or that is a loopback replay already tagged
///    with `x-fusillade-request-id`.
/// 2. Buffer and parse the JSON body (400 on read/parse failure).
/// 3. Inject account-resolved tools into `/responses` requests that did
///    not specify their own.
/// 4. Resolve the service tier (`flex` falls back to realtime on
///    endpoints other than `/responses` and `/chat/completions`) and try
///    the in-process warm path for realtime `/responses` requests with
///    tools.
/// 5. Otherwise create fusillade tracking state and delegate to the
///    realtime or flex handler.
#[tracing::instrument(skip_all)]
pub async fn responses_middleware<P: PoolProvider + Clone + Send + Sync + 'static>(
    State(state): State<ResponsesMiddlewareState<P>>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // Only POSTs to the inference endpoints are handled here.
    if !should_intercept(req.method(), req.uri().path()) {
        return next.run(req).await;
    }
    // Loopback re-entry: the daemon replays requests through this same
    // stack with this header set; don't intercept them a second time.
    if req.headers().get("x-fusillade-request-id").is_some() {
        return next.run(req).await;
    }
    let (parts, body) = req.into_parts();
    // Buffer the full body so it can be parsed here and replayed downstream.
    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
        Ok(bytes) => bytes,
        Err(e) => {
            tracing::error!(error = %e, "Failed to read request body in responses middleware");
            return Response::builder().status(StatusCode::BAD_REQUEST).body(Body::empty()).unwrap();
        }
    };
    let mut request_value: serde_json::Value = match serde_json::from_slice(&body_bytes) {
        Ok(v) => v,
        Err(e) => {
            tracing::error!(error = %e, "Failed to parse request body in responses middleware");
            return Response::builder().status(StatusCode::BAD_REQUEST).body(Body::empty()).unwrap();
        }
    };
    // Owned String kept alive for the whole function, shadowed by a &str
    // view for cheap passing below.
    let model = request_value["model"].as_str().unwrap_or("unknown").to_string();
    let model = model.as_str();
    let nested_path = parts.uri.path();
    let is_responses_api = nested_path.ends_with("/responses");
    let is_chat_completions_api = nested_path.ends_with("/chat/completions");
    let endpoint = format!("/v1{nested_path}");
    // Bearer token from the Authorization header, if present.
    let api_key = parts
        .headers
        .get("authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.strip_prefix("Bearer "))
        .map(|s| s.to_string());
    // Tool injection: only for /responses requests that did not supply
    // their own tools, and only when the key resolves to a tool set.
    if request_value.get("tools").is_none()
        && is_responses_api
        && let Some(key) = api_key.as_deref()
        && let Ok(Some(resolved)) = crate::tool_injection::resolve_tools_for_request(&state.dwctl_pool, key, Some(model)).await
    {
        let openai_tools = resolved.to_openai_tools_array();
        if !openai_tools.is_empty() {
            request_value["tools"] = serde_json::Value::Array(openai_tools);
        }
    }
    let requested_tier = resolve_service_tier(request_value["service_tier"].as_str());
    // `background` is only meaningful on the /responses API.
    let background = if is_responses_api {
        request_value["background"].as_bool().unwrap_or(false)
    } else {
        false
    };
    // Flex is only supported on /responses and /chat/completions; other
    // endpoints (e.g. embeddings) silently fall back to realtime.
    let service_tier = if matches!(requested_tier, ServiceTier::Flex) && !is_responses_api && !is_chat_completions_api {
        tracing::warn!(
            endpoint = %nested_path,
            "service_tier:'flex' is not yet supported on this endpoint; falling back to realtime."
        );
        ServiceTier::Realtime
    } else {
        requested_tier
    };
    let is_flex = matches!(service_tier, ServiceTier::Flex);
    let has_tools = request_value.get("tools").and_then(|v| v.as_array()).is_some_and(|a| !a.is_empty());
    let stream_requested = is_responses_api && !background && request_value["stream"].as_bool().unwrap_or(false);
    // Warm path: realtime /responses requests with tools may be executed
    // in-process. Each try_* helper returns None to fall through to the
    // regular fusillade routing below.
    match warm_path_branch(is_responses_api, is_flex, background, stream_requested, has_tools) {
        WarmPathBranch::Stream => {
            if let Some(resp) = try_warm_path_stream(&state, &request_value, api_key.as_deref(), model).await {
                return resp;
            }
        }
        WarmPathBranch::Blocking => {
            if let Some(resp) = try_warm_path_blocking(&state, &request_value, api_key.as_deref(), model).await {
                return resp;
            }
        }
        WarmPathBranch::Background => {
            if let Some(resp) = try_warm_path_background(&state, &request_value, api_key.as_deref(), model).await {
                return resp;
            }
        }
        WarmPathBranch::FallThrough => {}
    }
    tracing::debug!(
        model = %model,
        service_tier = %service_tier,
        background = background,
        endpoint = %endpoint,
        "Routing inference request"
    );
    // Client-visible response id derived from the internal tracking id.
    let request_id = uuid::Uuid::new_v4();
    let resp_id = format!("resp_{request_id}");
    // Flex work is executed later by the daemon, so auth and model-access
    // errors must be surfaced synchronously here.
    if matches!(service_tier, ServiceTier::Flex) {
        match api_key.as_deref() {
            None => {
                return Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .header("content-type", "application/json")
                    .body(Body::from(
                        serde_json::json!({"error": {"message": "API key required", "type": "invalid_request_error"}}).to_string(),
                    ))
                    .unwrap();
            }
            Some(key) => {
                if let Err(msg) = crate::error_enrichment::validate_api_key_model_access(state.dwctl_pool.clone(), key, model).await {
                    return Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .header("content-type", "application/json")
                        .body(Body::from(
                            serde_json::json!({"error": {"message": msg, "type": "invalid_request_error"}}).to_string(),
                        ))
                        .unwrap();
                }
            }
        }
    }
    // Attribution is only looked up when we answer before the downstream
    // call completes (background or flex).
    let needs_sync_attribution = background || matches!(service_tier, ServiceTier::Flex);
    let created_by = if needs_sync_attribution {
        response_store::lookup_created_by(&state.dwctl_pool, api_key.as_deref()).await
    } else {
        None
    };
    match service_tier {
        ServiceTier::Realtime => {
            let realtime_input = fusillade::CreateRealtimeInput {
                request_id,
                body: request_value.to_string(),
                model: model.to_string(),
                endpoint: state.loopback_base_url.clone(),
                method: "POST".to_string(),
                path: endpoint.clone(),
                api_key: api_key.clone().unwrap_or_default(),
                created_by: created_by.unwrap_or_default(),
            };
            handle_realtime(&state, realtime_input, &resp_id, model, background, parts, body_bytes, next).await
        }
        ServiceTier::Flex => {
            let flex_input = fusillade::CreateFlexInput {
                request_id,
                body: request_value.to_string(),
                model: model.to_string(),
                endpoint: state.loopback_base_url.clone(),
                method: "POST".to_string(),
                path: endpoint.clone(),
                api_key: api_key.clone().unwrap_or_default(),
                created_by: created_by.unwrap_or_default(),
            };
            if is_chat_completions_api {
                handle_chat_completion_flex(&state, flex_input, request_id).await
            } else {
                handle_flex(&state, flex_input, &resp_id, model, background).await
            }
        }
    }
}
/// Service tier requested by the client via the `service_tier` body field.
enum ServiceTier {
    Realtime,
    Flex,
}

impl std::fmt::Display for ServiceTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label once, then emit it — same output as formatting
        // each variant directly.
        let label = match self {
            ServiceTier::Realtime => "realtime",
            ServiceTier::Flex => "flex",
        };
        f.write_str(label)
    }
}

/// Map the raw `service_tier` string to a tier. Anything other than
/// `"flex"` — including `"priority"`, `"default"`, `"auto"`, or absence —
/// is treated as realtime.
fn resolve_service_tier(tier: Option<&str>) -> ServiceTier {
    if matches!(tier, Some("flex")) {
        ServiceTier::Flex
    } else {
        ServiceTier::Realtime
    }
}
/// Which warm-path execution mode (if any) a request qualifies for.
#[derive(Debug, PartialEq, Eq)]
enum WarmPathBranch {
    Stream,
    Blocking,
    Background,
    FallThrough,
}

/// Decide whether a request is eligible for the in-process warm path
/// and, if so, which mode to run it in. Only realtime `/responses`
/// requests that carry tools are eligible; streaming takes precedence
/// over background.
fn warm_path_branch(is_responses_api: bool, is_flex: bool, background: bool, stream_requested: bool, has_tools: bool) -> WarmPathBranch {
    let eligible = !is_flex && is_responses_api && has_tools;
    match (eligible, stream_requested, background) {
        (false, _, _) => WarmPathBranch::FallThrough,
        (true, true, _) => WarmPathBranch::Stream,
        (true, false, true) => WarmPathBranch::Background,
        (true, false, false) => WarmPathBranch::Blocking,
    }
}
/// Handle a realtime-tier request by replaying it downstream through
/// `next` after persisting tracking state.
///
/// * Background: a realtime tracking row is created directly
///   (best-effort), the downstream call is detached onto a task whose
///   response body is drained, and a synthetic `in_progress` response
///   object is returned immediately with 202.
/// * Foreground: a create-response job is enqueued (best-effort) and the
///   request is forwarded inline.
///
/// In both cases the replayed request is tagged with the raw fusillade
/// request id and onwards metadata headers before being forwarded.
#[allow(clippy::too_many_arguments)]
async fn handle_realtime<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    realtime_input: fusillade::CreateRealtimeInput,
    resp_id: &str,
    model: &str,
    background: bool,
    parts: axum::http::request::Parts,
    body_bytes: bytes::Bytes,
    next: Next,
) -> Response {
    let rm = state.request_manager.clone();
    // Saved before `realtime_input` is consumed below.
    let endpoint_for_header = realtime_input.path.clone();
    if background {
        // Failures here are logged but not fatal: the request still runs.
        if let Err(e) = fusillade::Storage::create_realtime(&*rm, realtime_input).await {
            tracing::warn!(error = %e, "Failed to create realtime tracking row");
        }
    } else {
        let job_input = CreateResponseInput {
            request_id: realtime_input.request_id,
            body: realtime_input.body,
            model: realtime_input.model,
            base_url: realtime_input.endpoint,
            endpoint: realtime_input.path,
            // Empty key is normalized to None.
            api_key: Some(realtime_input.api_key).filter(|s| !s.is_empty()),
        };
        tracing::debug!(
            request_id = %job_input.request_id,
            model = %job_input.model,
            endpoint = %job_input.endpoint,
            "responses_middleware enqueueing create-response job"
        );
        if let Err(e) = state.create_response_job.enqueue(&job_input).await {
            tracing::warn!(error = %e, "Failed to enqueue create-response job");
        }
    }
    // The fusillade header carries the bare UUID, without the "resp_" prefix.
    let raw_id = resp_id.strip_prefix("resp_").unwrap_or(resp_id);
    // Rebuild the request from the buffered body and tag it so the
    // loopback re-entry check in `responses_middleware` skips it.
    let mut req = Request::from_parts(parts, Body::from(body_bytes));
    req.headers_mut()
        .insert("x-fusillade-request-id", raw_id.parse().expect("response_id is valid header value"));
    req.headers_mut().insert(
        ONWARDS_RESPONSE_ID_HEADER,
        resp_id.parse().expect("response_id is valid header value"),
    );
    if let Ok(value) = endpoint_for_header.parse() {
        req.headers_mut().insert("x-onwards-endpoint", value);
    }
    if let Ok(value) = model.parse() {
        req.headers_mut().insert("x-onwards-model", value);
    }
    if background {
        // Acknowledge immediately; the actual call continues detached.
        let response_body = serde_json::json!({
            "id": resp_id,
            "object": "response",
            "status": "in_progress",
            "model": model,
            "background": true,
            "output": [],
        });
        tokio::spawn(async move {
            let response = next.run(req).await;
            // Drain the body so the downstream call fully completes.
            let (_parts, body) = response.into_parts();
            let _ = axum::body::to_bytes(body, usize::MAX).await;
        });
        (StatusCode::ACCEPTED, Json(response_body)).into_response()
    } else {
        next.run(req).await
    }
}
/// Handle a flex-tier `/responses` request: persist it as a fusillade
/// flex row, then either acknowledge immediately (background) or block
/// polling the store until the daemon completes it.
///
/// Returns 500 if the flex row cannot be created, 202 with a synthetic
/// `queued` object for background requests, the stored result for
/// completed blocking requests, and 504 if polling times out.
async fn handle_flex<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    flex_input: fusillade::CreateFlexInput,
    resp_id: &str,
    model: &str,
    background: bool,
) -> Response {
    // Unlike realtime, enqueue failure here is fatal: nothing else would
    // ever execute the request.
    if let Err(e) = fusillade::Storage::create_flex(&*state.request_manager, flex_input).await {
        tracing::error!(error = %e, "Failed to create flex row in fusillade");
        return Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .header("content-type", "application/json")
            .body(Body::from(
                serde_json::json!({
                    "error": {
                        "message": "Failed to enqueue request",
                        "type": "server_error",
                        "code": 500,
                    }
                })
                .to_string(),
            ))
            .unwrap();
    }
    if background {
        let response_body = serde_json::json!({
            "id": resp_id,
            "object": "response",
            "status": "queued",
            "model": model,
            "background": true,
            "service_tier": "flex",
            "output": [],
        });
        tracing::debug!(response_id = %resp_id, "Enqueued flex request");
        (StatusCode::ACCEPTED, Json(response_body)).into_response()
    } else {
        tracing::debug!(response_id = %resp_id, "Blocking flex — polling until daemon completes");
        // Poll every 500 ms for up to one hour.
        let poll_interval = std::time::Duration::from_millis(500);
        let timeout = std::time::Duration::from_secs(3600);
        match response_store::poll_until_complete(&state.request_manager, resp_id, poll_interval, timeout).await {
            Ok(response_obj) => {
                // Non-completed results carry their HTTP status in
                // error.code; default to 500 when absent/invalid.
                let status_code = if response_obj["status"].as_str() == Some("completed") {
                    StatusCode::OK
                } else {
                    response_obj["error"]["code"]
                        .as_u64()
                        .and_then(|c| StatusCode::from_u16(c as u16).ok())
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
                };
                (status_code, Json(response_obj)).into_response()
            }
            Err(e) => {
                tracing::error!(error = %e, response_id = %resp_id, "Blocking flex poll failed");
                let response_body = serde_json::json!({
                    "error": {
                        "message": format!("Request timed out: {e}"),
                        "type": "server_error",
                    }
                });
                (StatusCode::GATEWAY_TIMEOUT, Json(response_body)).into_response()
            }
        }
    }
}
/// Handle a flex-tier `/chat/completions` request: enqueue it as a
/// fusillade flex row, block until the daemon drives it to a terminal
/// state, then translate the stored detail back into an OpenAI
/// chat-completion response.
///
/// Returns 500 if the flex row cannot be created and 504 if polling
/// times out; otherwise the status comes from the stored detail.
async fn handle_chat_completion_flex<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    flex_input: fusillade::CreateFlexInput,
    request_id: uuid::Uuid,
) -> Response {
    // Enqueue failure is fatal: nothing else would execute the request.
    if let Err(e) = fusillade::Storage::create_flex(&*state.request_manager, flex_input).await {
        tracing::error!(error = %e, "Failed to create flex chat-completions batch in fusillade");
        return Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .header("content-type", "application/json")
            .body(Body::from(
                serde_json::json!({
                    "error": {
                        "message": "Failed to enqueue request",
                        "type": "server_error",
                        "code": 500,
                    }
                })
                .to_string(),
            ))
            .unwrap();
    }
    // Poll every 500 ms for up to one hour (same budget as handle_flex).
    let poll_interval = std::time::Duration::from_millis(500);
    let timeout = std::time::Duration::from_secs(3600);
    match response_store::poll_until_terminal(&state.request_manager, request_id, poll_interval, timeout).await {
        Ok(detail) => {
            let (status, body) = response_store::detail_to_chat_completion_object(&detail);
            let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            tracing::debug!(request_id = %request_id, %status_code, "Flex chat-completions terminal");
            (status_code, Json(body)).into_response()
        }
        Err(e) => {
            tracing::error!(error = %e, request_id = %request_id, "Blocking flex chat-completions poll failed");
            (
                StatusCode::GATEWAY_TIMEOUT,
                Json(serde_json::json!({
                    "error": {
                        "message": format!("Request timed out: {e}"),
                        "type": "server_error",
                        "code": 504,
                    }
                })),
            )
                .into_response()
        }
    }
}
/// True for POST requests to one of the intercepted inference endpoints
/// (`…/responses`, `…/chat/completions`, `…/embeddings`), matched by
/// path suffix so both `/v1/...` and bare forms qualify.
pub(crate) fn should_intercept(method: &axum::http::Method, path: &str) -> bool {
    let is_post = method == axum::http::Method::POST;
    let has_inference_suffix = ["/responses", "/chat/completions", "/embeddings"]
        .iter()
        .any(|suffix| path.ends_with(suffix));
    is_post && has_inference_suffix
}
/// Warm-path streaming: run the multi-step loop in-process and return an
/// SSE response. Returns `None` (falling back to regular routing) when
/// no API key was supplied or warm-path setup fails.
async fn try_warm_path_stream<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    request_value: &serde_json::Value,
    api_key: Option<&str>,
    model: &str,
) -> Option<Response> {
    let key = api_key?;
    let (head_step_uuid, resolved, upstream) = warm_path_setup(state, request_value, key, model).await?;
    let response = super::streaming::run_inline_streaming(
        state.response_store.clone(),
        state.multi_step_tool_executor.clone(),
        resolved,
        state.multi_step_http_client.clone(),
        upstream,
        state.loop_config,
        head_step_uuid.to_string(),
        model.to_string(),
    )
    .into_response();
    Some(response)
}
/// Warm-path blocking execution for non-streaming, non-background
/// `/responses` requests with tools.
///
/// Returns `None` (falling back to regular routing) when no API key was
/// supplied or warm-path setup fails; otherwise runs the multi-step loop
/// inline and returns 200 with the result, or 502 with the error payload
/// on upstream failure.
async fn try_warm_path_blocking<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    request_value: &serde_json::Value,
    api_key: Option<&str>,
    model: &str,
) -> Option<Response> {
    let api_key = api_key?;
    // `?` on setup failure, consistent with try_warm_path_stream.
    let (request_id, resolved, upstream) = warm_path_setup(state, request_value, api_key, model).await?;
    let result = super::streaming::run_inline_blocking(
        state.response_store.clone(),
        state.multi_step_tool_executor.clone(),
        resolved,
        state.multi_step_http_client.clone(),
        upstream,
        state.loop_config,
        request_id.to_string(),
        model.to_string(),
    )
    .await;
    let (status, body) = match result {
        Ok(json) => (StatusCode::OK, json),
        Err(err_payload) => (StatusCode::BAD_GATEWAY, serde_json::json!({"error": err_payload})),
    };
    Some((status, Json(body)).into_response())
}
/// Warm-path background execution: acknowledge the request with 202 and
/// an `in_progress` response object, then run the multi-step loop to
/// completion in a detached task (its result is persisted by the
/// response store, not returned here).
///
/// Returns `None` (falling back to regular routing) when no API key was
/// supplied or warm-path setup fails.
async fn try_warm_path_background<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    request_value: &serde_json::Value,
    api_key: Option<&str>,
    model: &str,
) -> Option<Response> {
    let api_key = api_key?;
    // `?` on setup failure, consistent with try_warm_path_stream.
    let (request_id, resolved, upstream) = warm_path_setup(state, request_value, api_key, model).await?;
    let resp_id = format!("resp_{request_id}");
    let response_body = serde_json::json!({
        "id": resp_id,
        "object": "response",
        "status": "in_progress",
        "model": model,
        "background": true,
        "output": [],
    });
    // Clone everything the detached task needs; borrowed values are
    // converted to owned so the task is 'static.
    let response_store = state.response_store.clone();
    let tool_executor = state.multi_step_tool_executor.clone();
    let http_client = state.multi_step_http_client.clone();
    let loop_config = state.loop_config;
    let model_str = model.to_string();
    let request_id_str = request_id.to_string();
    tokio::spawn(async move {
        let _ = super::streaming::run_inline_blocking(
            response_store,
            tool_executor,
            resolved,
            http_client,
            upstream,
            loop_config,
            request_id_str,
            model_str,
        )
        .await;
    });
    Some((StatusCode::ACCEPTED, Json(response_body)).into_response())
}
/// Shared setup for all warm-path branches.
///
/// Looks up attribution, resolves the caller's tool set (falling back to
/// an empty set when none is configured or resolution errors), registers
/// the pending input in the response store under a fresh head-step id,
/// and creates the `/v1/responses` tracking row in fusillade. Returns
/// the head-step id, the resolved tool set, and the loopback upstream
/// target; returns `None` — after unregistering the pending entry where
/// needed — if either persistence step fails, aborting the warm path.
async fn warm_path_setup<P: PoolProvider + Clone + Send + Sync + 'static>(
    state: &ResponsesMiddlewareState<P>,
    request_value: &serde_json::Value,
    api_key: &str,
    model: &str,
) -> Option<(uuid::Uuid, Arc<crate::tool_executor::ResolvedToolSet>, onwards::UpstreamTarget)> {
    let created_by = response_store::lookup_created_by(&state.dwctl_pool, Some(api_key)).await;
    // Tool resolution failure is non-fatal: proceed with an empty set.
    let resolved = match crate::tool_injection::resolve_tools_for_request(&state.dwctl_pool, api_key, Some(model)).await {
        Ok(Some(set)) => Arc::new(set),
        Ok(None) => Arc::new(crate::tool_executor::ResolvedToolSet::new(
            std::collections::HashMap::new(),
            std::collections::HashMap::new(),
        )),
        Err(e) => {
            tracing::warn!(error = %e, "warm-path: tool resolution failed; running with no tools");
            Arc::new(crate::tool_executor::ResolvedToolSet::new(
                std::collections::HashMap::new(),
                std::collections::HashMap::new(),
            ))
        }
    };
    let resolved_tool_names = resolved.tools.keys().cloned().collect();
    // Fresh id doubles as both the fusillade request id and the store key.
    let head_step_uuid = uuid::Uuid::new_v4();
    let pending = response_store::PendingResponseInput {
        body: request_value.to_string(),
        api_key: Some(api_key.to_string()),
        created_by: created_by.clone(),
        base_url: state.loopback_base_url.clone(),
        resolved_tool_names,
    };
    if let Err(e) = state.response_store.register_pending_with_id(head_step_uuid, pending) {
        tracing::error!(
            error = %e,
            request_id = %head_step_uuid,
            "warm-path: failed to register pending input — aborting warm path",
        );
        return None;
    }
    let realtime_input = fusillade::CreateRealtimeInput {
        request_id: head_step_uuid,
        body: request_value.to_string(),
        model: model.to_string(),
        endpoint: state.loopback_base_url.clone(),
        method: "POST".to_string(),
        path: "/v1/responses".to_string(),
        api_key: api_key.to_string(),
        created_by: created_by.unwrap_or_default(),
    };
    if let Err(e) = fusillade::Storage::create_realtime(&*state.request_manager, realtime_input).await {
        tracing::error!(
            error = %e,
            request_id = %head_step_uuid,
            "warm-path: failed to create /v1/responses tracking row — aborting warm path",
        );
        // Roll back the pending registration so the store doesn't leak.
        state.response_store.unregister_pending(&head_step_uuid.to_string());
        return None;
    }
    // The warm-path loop talks to this same service over loopback.
    let upstream = onwards::UpstreamTarget {
        url: format!("{}/v1/chat/completions", state.loopback_base_url),
        api_key: Some(api_key.to_string()),
    };
    Some((head_step_uuid, resolved, upstream))
}
// Unit tests for the pure routing helpers: `should_intercept`,
// `resolve_service_tier`, and `warm_path_branch`.
#[cfg(test)]
mod tests {
    use super::*;

    // --- should_intercept: only POSTs to the inference endpoints ---
    #[test]
    fn test_should_intercept_responses() {
        assert!(should_intercept(&axum::http::Method::POST, "/v1/responses"));
        assert!(should_intercept(&axum::http::Method::POST, "/responses"));
    }
    #[test]
    fn test_should_intercept_chat_completions() {
        assert!(should_intercept(&axum::http::Method::POST, "/v1/chat/completions"));
    }
    #[test]
    fn test_should_intercept_embeddings() {
        assert!(should_intercept(&axum::http::Method::POST, "/v1/embeddings"));
    }
    #[test]
    fn test_should_not_intercept_get() {
        assert!(!should_intercept(&axum::http::Method::GET, "/v1/responses"));
    }
    #[test]
    fn test_should_not_intercept_models() {
        assert!(!should_intercept(&axum::http::Method::GET, "/v1/models"));
        assert!(!should_intercept(&axum::http::Method::POST, "/v1/models"));
    }
    #[test]
    fn test_should_not_intercept_batches() {
        assert!(!should_intercept(&axum::http::Method::POST, "/v1/batches"));
    }
    #[test]
    fn test_should_not_intercept_files() {
        assert!(!should_intercept(&axum::http::Method::POST, "/v1/files"));
    }

    // --- resolve_service_tier: everything except "flex" is realtime ---
    #[test]
    fn test_resolve_service_tier_priority_is_realtime() {
        assert!(matches!(resolve_service_tier(Some("priority")), ServiceTier::Realtime));
    }
    #[test]
    fn test_resolve_service_tier_default_is_realtime() {
        assert!(matches!(resolve_service_tier(Some("default")), ServiceTier::Realtime));
    }
    #[test]
    fn test_resolve_service_tier_auto_is_realtime() {
        assert!(matches!(resolve_service_tier(Some("auto")), ServiceTier::Realtime));
    }
    #[test]
    fn test_resolve_service_tier_none_is_realtime() {
        assert!(matches!(resolve_service_tier(None), ServiceTier::Realtime));
    }
    #[test]
    fn test_resolve_service_tier_flex() {
        assert!(matches!(resolve_service_tier(Some("flex")), ServiceTier::Flex));
    }

    // --- warm_path_branch: eligibility and mode selection ---
    #[test]
    fn warm_path_branch_flex_responses_falls_through_to_handle_flex() {
        for &background in &[false, true] {
            for &stream in &[false, true] {
                for &has_tools in &[false, true] {
                    assert_eq!(
                        warm_path_branch(true, true, background, stream, has_tools),
                        WarmPathBranch::FallThrough,
                        "flex /v1/responses must fall through (background={background}, stream={stream}, has_tools={has_tools})"
                    );
                }
            }
        }
    }
    #[test]
    fn warm_path_branch_realtime_responses_with_tools_picks_correct_warm_branch() {
        assert_eq!(warm_path_branch(true, false, false, true, true), WarmPathBranch::Stream);
        assert_eq!(warm_path_branch(true, false, true, false, true), WarmPathBranch::Background);
        assert_eq!(warm_path_branch(true, false, false, false, true), WarmPathBranch::Blocking);
    }
    #[test]
    fn warm_path_branch_realtime_responses_without_tools_falls_through() {
        for &background in &[false, true] {
            for &stream in &[false, true] {
                assert_eq!(
                    warm_path_branch(true, false, background, stream, false),
                    WarmPathBranch::FallThrough,
                    "tool-free realtime /v1/responses must fall through (background={background}, stream={stream})"
                );
            }
        }
    }
    #[test]
    fn warm_path_branch_chat_completions_always_falls_through() {
        for &has_tools in &[false, true] {
            assert_eq!(warm_path_branch(false, false, false, false, has_tools), WarmPathBranch::FallThrough);
            assert_eq!(warm_path_branch(false, true, false, false, has_tools), WarmPathBranch::FallThrough);
            assert_eq!(warm_path_branch(false, false, false, true, has_tools), WarmPathBranch::FallThrough);
        }
    }
}