use std::sync::Arc;
use async_trait::async_trait;
use fusillade::{BatchInput, PostgresRequestManager, RequestId, RequestTemplateInput, ReqwestHttpClient, Storage};
use onwards::{ResponseStore, StoreError};
use sqlx_pool_router::PoolProvider;
use uuid::Uuid;
/// HTTP header name `x-onwards-response-id`.
///
/// Not read in this file; presumably carries the `resp_<uuid>` id format that
/// `parse_response_id` accepts — confirm against the code that sets/reads it.
pub const ONWARDS_RESPONSE_ID_HEADER: &str = "x-onwards-response-id";
/// Newtype wrapper around a [`Uuid`].
///
/// Not referenced elsewhere in this file; presumably identifies an onwards daemon instance —
/// confirm against its users.
#[derive(Clone, Copy, Debug)]
pub struct OnwardsDaemonId(pub Uuid);
/// `ResponseStore` backed by fusillade's Postgres request manager.
///
/// Reads go through the shared `request_manager`; rows are rendered into response objects on
/// demand rather than stored separately.
pub struct FusilladeResponseStore<P>
where
    P: PoolProvider + Clone,
{
    /// Shared handle used for every request lookup.
    request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>,
}
impl<P: PoolProvider + Clone> FusilladeResponseStore<P> {
pub fn new(request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>) -> Self {
Self { request_manager }
}
pub async fn get_response(&self, response_id: &str) -> Result<Option<serde_json::Value>, StoreError> {
let id = parse_response_id(response_id)?;
match self.request_manager.get_request_detail(RequestId(id)).await {
Ok(detail) => Ok(Some(detail_to_response_object(&detail))),
Err(fusillade::FusilladeError::RequestNotFound(_)) => Ok(None),
Err(e) => Err(StoreError::StorageError(format!("Failed to fetch request: {e}"))),
}
}
}
/// Calls fusillade's `fail_request` for the request behind `response_id`, recording `error`.
///
/// # Errors
/// - `StoreError::NotFound` if `response_id` is not a valid (optionally `resp_`-prefixed) UUID.
/// - `StoreError::StorageError` if `fail_request` returns an error.
pub async fn fail_response<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    response_id: &str,
    error: &str,
) -> Result<(), StoreError> {
    let request_id = RequestId(parse_response_id(response_id)?);
    request_manager
        .fail_request(request_id, error)
        .await
        .map(|_| ())
        .map_err(|e| StoreError::StorageError(format!("Failed to fail request: {e}")))
}
/// Reports whether fusillade has a request row with id `request_id`.
///
/// `RequestNotFound` maps to `Ok(false)`; a successful fetch maps to `Ok(true)`.
///
/// # Errors
/// `StoreError::StorageError` for any other fusillade failure.
pub async fn request_exists<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    request_id: Uuid,
) -> Result<bool, StoreError> {
    let found = request_manager
        .get_request_detail(RequestId(request_id))
        .await
        .map(|_| true);
    match found {
        Err(fusillade::FusilladeError::RequestNotFound(_)) => Ok(false),
        other => other.map_err(|e| StoreError::StorageError(format!("Failed to check request existence: {e}"))),
    }
}
/// Inputs used by `complete_response_idempotent` to synthesize a request row when the
/// completion finds none.
///
/// All string fields are borrowed from the caller for lifetime `'a`.
pub struct CreateContext<'a> {
    /// `batch_id` of the synthesized single-request batch.
    pub batch_id: Uuid,
    /// `request_id` of the synthesized request.
    pub request_id: Uuid,
    /// Request body stored on the synthesized row.
    pub request_body: &'a str,
    /// Model name stored on the synthesized row.
    pub model: &'a str,
    /// Endpoint stored on the synthesized row; synthesis is refused when this is empty
    /// (per the error text, it comes from an upstream `x-onwards-endpoint` header).
    pub endpoint: &'a str,
    /// Base URL stored on the synthesized row.
    pub base_url: &'a str,
    /// Caller's API key, if any; forwarded to the synthesized row and used to look up
    /// `created_by`.
    pub api_key: Option<&'a str>,
}
/// Completes the request behind `response_id`, synthesizing its row first when it does not
/// exist yet.
///
/// Flow:
/// 1. Try `complete_request` directly.
/// 2. On `RequestNotFound` (the log text attributes this to create-response not having run
///    yet), create a single-request batch from `create_ctx` with `initial_state` `"processing"`
///    and a `"0s"` completion window, then retry `complete_request` once.
/// 3. A failed synthetic create is only logged: the log text attributes it to create-response
///    having created the row concurrently, so the retried completion decides the result.
///
/// # Errors
/// - `StoreError::NotFound` if `response_id` is not a valid (optionally `resp_`-prefixed) UUID.
/// - `StoreError::StorageError` if the first completion fails with anything other than
///   `RequestNotFound`, if a row must be synthesized but `create_ctx.endpoint` is empty, or if
///   the retried completion fails.
///
/// NOTE(review): completion targets the id parsed from `response_id`, but the synthetic row is
/// created with `create_ctx.request_id`. Callers presumably keep the two equal; if they diverge,
/// the retry cannot see the synthesized row. Confirm at call sites.
pub async fn complete_response_idempotent<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    dwctl_pool: &sqlx::PgPool,
    response_id: &str,
    response_body: &str,
    status_code: u16,
    create_ctx: CreateContext<'_>,
) -> Result<(), StoreError> {
    let id = parse_response_id(response_id)?;
    // Fast path: the row already exists.
    match request_manager.complete_request(RequestId(id), response_body, status_code).await {
        Ok(()) => Ok(()),
        Err(fusillade::FusilladeError::RequestNotFound(_)) => {
            tracing::info!(
                response_id = %response_id,
                model = %create_ctx.model,
                endpoint = %create_ctx.endpoint,
                "complete-response synthesizing row (create-response hasn't run yet)"
            );
            // Refuse to synthesize a row without an endpoint rather than store an empty one.
            if create_ctx.endpoint.is_empty() {
                return Err(StoreError::StorageError(
                    "Cannot synthesize request row: empty endpoint in CreateContext (x-onwards-endpoint header missing upstream)".into(),
                ));
            }
            // Best-effort attribution: `None` leaves the synthesized batch unattributed.
            let created_by = lookup_created_by(dwctl_pool, create_ctx.api_key).await;
            let batch_input = fusillade::CreateSingleRequestBatchInput {
                batch_id: Some(create_ctx.batch_id),
                request_id: create_ctx.request_id,
                body: create_ctx.request_body.to_string(),
                model: create_ctx.model.to_string(),
                base_url: create_ctx.base_url.to_string(),
                endpoint: create_ctx.endpoint.to_string(),
                completion_window: "0s".to_string(),
                initial_state: "processing".to_string(),
                api_key: create_ctx.api_key.map(String::from),
                created_by,
            };
            // Both outcomes fall through to the retry below; a create error is deliberately
            // non-fatal because the row may already exist.
            match request_manager.create_single_request_batch(batch_input).await {
                Ok(_) => {
                    tracing::info!(
                        response_id = %response_id,
                        "Synthetic create from complete-response succeeded — row now exists in 'processing'"
                    );
                }
                Err(e) => {
                    tracing::info!(
                        response_id = %response_id,
                        error = %e,
                        "Synthetic create from complete-response failed (likely create-response won the race) — proceeding to UPDATE"
                    );
                }
            }
            // Retry once, now that a row should exist (either ours or create-response's).
            match request_manager.complete_request(RequestId(id), response_body, status_code).await {
                Ok(()) => {
                    tracing::info!(response_id = %response_id, "Second-attempt UPDATE succeeded — row now 'completed'");
                    Ok(())
                }
                Err(e) => {
                    tracing::warn!(response_id = %response_id, error = %e, "Second-attempt UPDATE failed");
                    Err(StoreError::StorageError(format!("Failed to complete after create: {e}")))
                }
            }
        }
        Err(e) => Err(StoreError::StorageError(format!("Failed to complete request: {e}"))),
    }
}
/// Polls fusillade until the request behind `response_id` reaches a terminal state
/// (`completed`, `failed`, or `canceled`), then returns it rendered as a response object.
///
/// A `RequestNotFound` result is not an error here: the row may simply not exist yet, so it is
/// polled again. Each wait is capped at the time left before `timeout`, so the final poll
/// happens at roughly the deadline rather than up to a full `poll_interval` after it.
///
/// # Errors
/// - `StoreError::NotFound` if `response_id` is not a valid (optionally `resp_`-prefixed) UUID.
/// - `StoreError::StorageError` on any fetch error other than `RequestNotFound`, or when
///   `timeout` elapses without reaching a terminal state.
pub async fn poll_until_complete<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    response_id: &str,
    poll_interval: std::time::Duration,
    timeout: std::time::Duration,
) -> Result<serde_json::Value, StoreError> {
    let id = parse_response_id(response_id)?;
    let start = std::time::Instant::now();
    loop {
        match request_manager.get_request_detail(RequestId(id)).await {
            Ok(detail) if matches!(detail.status.as_str(), "completed" | "failed" | "canceled") => {
                return Ok(detail_to_response_object(&detail));
            }
            // Row exists but is not terminal yet, or has not been created yet: poll again.
            Ok(_) | Err(fusillade::FusilladeError::RequestNotFound(_)) => {}
            Err(e) => {
                return Err(StoreError::StorageError(format!("Failed to poll request: {e}")));
            }
        }
        let elapsed = start.elapsed();
        if elapsed >= timeout {
            return Err(StoreError::StorageError(format!(
                "Timeout waiting for request {response_id} to complete after {:?}",
                timeout
            )));
        }
        // Never sleep past the deadline (elapsed < timeout here, so the subtraction is > 0).
        tokio::time::sleep(poll_interval.min(timeout.saturating_sub(elapsed))).await;
    }
}
pub async fn lookup_created_by(pool: &sqlx::PgPool, api_key: Option<&str>) -> Option<String> {
let key = api_key?;
match sqlx::query("SELECT user_id FROM public.api_keys WHERE secret = $1 AND is_deleted = false LIMIT 1")
.bind(key)
.fetch_optional(pool)
.await
{
Ok(Some(row)) => {
use sqlx::Row;
let user_id: Uuid = row.get("user_id");
Some(user_id.to_string())
}
Ok(None) => {
tracing::warn!(key_prefix = &key[..8.min(key.len())], "API key not found for attribution");
None
}
Err(e) => {
tracing::error!(error = %e, "Failed to look up API key for attribution");
None
}
}
}
/// Creates a one-request batch for `request` and returns `("resp_<request_id>", request_id)`.
///
/// The steps are: create a file holding a single `POST` template, create a batch over it with
/// `total_requests = 1`, and read back the id of the batch's only request.
///
/// # Errors
/// `StoreError::StorageError` if file creation, batch creation, or the request read-back fails,
/// or if the created batch contains no requests.
pub async fn create_batch_of_1<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    request: &serde_json::Value,
    model: &str,
    base_url: &str,
    path: &str,
    completion_window: &str,
    api_key: Option<&str>,
) -> Result<(String, Uuid), StoreError> {
    // Best-effort attribution: `None` (no key, unknown key, or lookup failure) leaves the batch
    // unattributed. Passed through as an `Option` directly instead of via an empty-string sentinel.
    let created_by = lookup_created_by(request_manager.pool(), api_key).await;

    let template = RequestTemplateInput {
        custom_id: None,
        endpoint: base_url.to_string(),
        method: "POST".to_string(),
        path: path.to_string(),
        body: request.to_string(),
        model: model.to_string(),
        // NOTE(review): the template carries an empty key while the batch below carries
        // `api_key`; presumably fusillade takes credentials from the batch — confirm.
        api_key: String::new(),
    };
    let file_id = request_manager
        .create_file("responses_api_single".into(), None, vec![template])
        .await
        .map_err(|e| StoreError::StorageError(format!("Failed to create file: {e}")))?;

    let batch = request_manager
        .create_batch(BatchInput {
            file_id,
            endpoint: path.to_string(),
            completion_window: completion_window.to_string(),
            metadata: None,
            created_by,
            api_key_id: None,
            api_key: api_key.map(String::from),
            total_requests: Some(1),
        })
        .await
        .map_err(|e| StoreError::StorageError(format!("Failed to create batch: {e}")))?;

    let requests = request_manager
        .get_batch_requests(batch.id)
        .await
        .map_err(|e| StoreError::StorageError(format!("Failed to get batch requests: {e}")))?;
    let request_id = requests
        .first()
        .map(|r| *r.id())
        .ok_or_else(|| StoreError::StorageError("Batch created with no requests".into()))?;

    let response_id = format!("resp_{request_id}");
    tracing::debug!(
        response_id = %response_id,
        batch_id = %batch.id,
        completion_window = %completion_window,
        "Created batch of 1 for async processing"
    );
    Ok((response_id, request_id))
}
/// Maps an upstream HTTP error (`status` plus raw `body`) to `(error_type, message)`.
///
/// The message is the OpenAI-style `error.message` from `body` when present; otherwise it is
/// the whole body, verbatim.
fn extract_upstream_error(status: u16, body: &str) -> (&'static str, String) {
    let message = parse_openai_error(body).unwrap_or_else(|| body.to_string());
    (status_to_error_type(status), message)
}
/// Decodes a stored failure string into `(error_type, message, upstream_status)`.
///
/// Recognised shapes, in order:
/// 1. JSON with `details.body` (and optionally `details.status`): the status picks the error
///    type (500 when it is missing or out of `u16` range).
/// 2. `"Upstream returned <status>: <body>"`.
/// 3. Anything else, which is reported as `server_error` with no status.
///
/// In every case the message is the OpenAI-style `error.message` when the body has one;
/// otherwise it is the body text, verbatim.
fn parse_failure_error(err: &str) -> (&'static str, String, Option<u16>) {
    let message_from = |text: &str| parse_openai_error(text).unwrap_or_else(|| text.to_string());

    // Shape 1: structured JSON. Without a string `details.body` we fall through to the others.
    let structured = serde_json::from_str::<serde_json::Value>(err).ok();
    if let Some(details) = structured.as_ref().and_then(|reason| reason.get("details")) {
        let code = details
            .get("status")
            .and_then(serde_json::Value::as_u64)
            .and_then(|raw| u16::try_from(raw).ok());
        if let Some(body) = details.get("body").and_then(serde_json::Value::as_str) {
            return (status_to_error_type(code.unwrap_or(500)), message_from(body), code);
        }
    }

    // Shape 2: legacy plain-text prefix with a numeric status before the first ": ".
    let legacy = err.strip_prefix("Upstream returned ").and_then(|rest| {
        let (code_text, body) = rest.split_once(": ")?;
        let code = code_text.parse::<u16>().ok()?;
        Some((code, body))
    });
    if let Some((code, body)) = legacy {
        return (status_to_error_type(code), message_from(body), Some(code));
    }

    // Shape 3: unrecognised.
    ("server_error", message_from(err), None)
}
/// Extracts `error.message` from an OpenAI-style JSON error body.
///
/// Returns `None` if `body` is not JSON, has no `error` object, or `error.message` is not a
/// string.
fn parse_openai_error(body: &str) -> Option<String> {
    let document: serde_json::Value = serde_json::from_str(body).ok()?;
    document
        .get("error")
        .and_then(|error| error.get("message"))
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
}
/// Maps an HTTP status code to an error `type` string.
///
/// Only 400, 401, 402, 403, 404 and 429 have dedicated types; every other code is
/// `"server_error"`.
fn status_to_error_type(status: u16) -> &'static str {
    const KNOWN: [(u16, &str); 6] = [
        (400, "invalid_request_error"),
        (401, "authentication_error"),
        (402, "insufficient_credits"),
        (403, "permission_error"),
        (404, "not_found_error"),
        (429, "rate_limit_error"),
    ];
    KNOWN
        .iter()
        .find_map(|&(code, kind)| (code == status).then_some(kind))
        .unwrap_or("server_error")
}
/// Parses a response id (`resp_<uuid>` or a bare `<uuid>`) into its [`Uuid`].
///
/// # Errors
/// `StoreError::NotFound` when the remainder is not a valid UUID, so an unparseable id is
/// treated the same as an unknown one.
fn parse_response_id(response_id: &str) -> Result<Uuid, StoreError> {
    let raw = match response_id.strip_prefix("resp_") {
        Some(stripped) => stripped,
        None => response_id,
    };
    Uuid::parse_str(raw).map_err(|err| StoreError::NotFound(format!("Invalid response ID: {err}")))
}
/// Translates a fusillade request state into the response `status` string.
///
/// Note the spelling change: the stored state `"canceled"` is emitted as `"cancelled"`.
/// Any unrecognised state is reported as `"failed"`.
fn state_to_status(state: &str) -> &'static str {
    if state == "pending" {
        "queued"
    } else if matches!(state, "claimed" | "processing") {
        "in_progress"
    } else if state == "completed" {
        "completed"
    } else if state == "canceled" {
        "cancelled"
    } else {
        // Covers "failed" itself as well as any state this mapping does not know about.
        "failed"
    }
}
fn detail_to_response_object(detail: &fusillade::RequestDetail) -> serde_json::Value {
let status = state_to_status(&detail.status);
let background = detail
.body
.as_deref()
.and_then(|b| serde_json::from_str::<serde_json::Value>(b).ok())
.and_then(|v| v.get("background")?.as_bool())
.unwrap_or(false);
let mut resp = serde_json::json!({
"id": format!("resp_{}", detail.id),
"object": "response",
"created_at": detail.created_at.timestamp(),
"status": status,
"model": detail.model,
"background": background,
"output": [],
});
if status == "completed" {
let response_status = match detail.response_status {
Some(s) => u16::try_from(s).unwrap_or(500),
None => 200,
};
let is_error_response = response_status >= 400;
if is_error_response {
resp["status"] = serde_json::json!("failed");
let (error_type, message) = if let Some(ref body) = detail.response_body {
extract_upstream_error(response_status, body)
} else {
(
status_to_error_type(response_status),
format!("Upstream returned {response_status}"),
)
};
resp["error"] = serde_json::json!({
"type": error_type,
"code": response_status,
"message": message,
});
} else if let Some(ref body) = detail.response_body
&& let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body)
{
if let Some(output) = parsed.get("output") {
resp["output"] = output.clone();
}
if let Some(usage) = parsed.get("usage") {
resp["usage"] = usage.clone();
}
if parsed.get("choices").is_some() {
resp["output"] = serde_json::json!([{
"type": "message",
"role": "assistant",
"content": parsed
}]);
}
}
resp["completed_at"] = serde_json::json!(detail.completed_at.map(|t| t.timestamp()));
}
if status == "failed"
&& let Some(ref err) = detail.error
{
let (error_type, message, status_code) = parse_failure_error(err);
let mut error_obj = serde_json::json!({
"type": error_type,
"message": message,
});
if let Some(code) = status_code {
error_obj["code"] = serde_json::json!(code);
}
resp["error"] = error_obj;
}
resp
}
#[async_trait]
impl<P> ResponseStore for FusilladeResponseStore<P>
where
    P: PoolProvider + Clone + Send + Sync + 'static,
{
    /// Persists nothing; only echoes back the response's `id` field.
    ///
    /// Yields an empty string when `id` is missing or is not a JSON string.
    async fn store(&self, response: &serde_json::Value) -> Result<String, StoreError> {
        let id = match response.get("id") {
            Some(serde_json::Value::String(id)) => id.clone(),
            _ => String::new(),
        };
        Ok(id)
    }

    /// Delegates to `FusilladeResponseStore::get_response`.
    async fn get_context(&self, response_id: &str) -> Result<Option<serde_json::Value>, StoreError> {
        self.get_response(response_id).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_response_id_with_prefix() {
        let uuid = Uuid::new_v4();
        let id = format!("resp_{uuid}");
        let parsed = parse_response_id(&id).unwrap();
        assert_eq!(parsed, uuid);
    }

    #[test]
    fn test_parse_response_id_without_prefix() {
        let uuid = Uuid::new_v4();
        let parsed = parse_response_id(&uuid.to_string()).unwrap();
        assert_eq!(parsed, uuid);
    }

    #[test]
    fn test_parse_response_id_invalid() {
        let result = parse_response_id("not-a-uuid");
        assert!(result.is_err());
        assert!(matches!(result, Err(StoreError::NotFound(_))));
    }

    #[test]
    fn test_parse_response_id_prefix_only() {
        assert!(matches!(parse_response_id("resp_"), Err(StoreError::NotFound(_))));
    }

    #[test]
    fn test_parse_response_id_uppercase_hex() {
        let uuid = Uuid::new_v4();
        let id = format!("resp_{}", uuid.to_string().to_uppercase());
        assert_eq!(parse_response_id(&id).unwrap(), uuid);
    }

    #[test]
    fn test_state_to_status_mapping() {
        assert_eq!(state_to_status("pending"), "queued");
        assert_eq!(state_to_status("claimed"), "in_progress");
        assert_eq!(state_to_status("processing"), "in_progress");
        assert_eq!(state_to_status("completed"), "completed");
        assert_eq!(state_to_status("failed"), "failed");
        assert_eq!(state_to_status("canceled"), "cancelled");
        assert_eq!(state_to_status("unknown"), "failed");
    }

    // NOTE: the two `store` tests below mirror the id-extraction expression used by
    // `ResponseStore::store` rather than calling it, because constructing a
    // `FusilladeResponseStore` needs a live `PostgresRequestManager`. Keep them in sync.
    #[test]
    fn test_store_extracts_id_from_response() {
        let response = serde_json::json!({
            "id": "resp_12345678-1234-1234-1234-123456789abc",
            "status": "completed",
        });
        let id = response.get("id").and_then(|v| v.as_str()).unwrap_or("");
        assert_eq!(id, "resp_12345678-1234-1234-1234-123456789abc");
    }

    #[test]
    fn test_store_handles_missing_id() {
        let response = serde_json::json!({"status": "completed"});
        let id = response.get("id").and_then(|v| v.as_str()).unwrap_or("");
        assert_eq!(id, "");
    }

    #[test]
    fn test_extract_upstream_error_openai_format() {
        let body = r#"{"error":{"message":"Forbidden","type":"invalid_request_error","param":null,"code":"forbidden"}}"#;
        let (error_type, message) = extract_upstream_error(403, body);
        assert_eq!(error_type, "permission_error");
        assert_eq!(message, "Forbidden");
    }

    #[test]
    fn test_extract_upstream_error_rate_limit() {
        let body = r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit"}}"#;
        let (error_type, message) = extract_upstream_error(429, body);
        assert_eq!(error_type, "rate_limit_error");
        assert_eq!(message, "Rate limit exceeded");
    }

    #[test]
    fn test_extract_upstream_error_plain_text() {
        let (error_type, message) = extract_upstream_error(402, "Account balance too low");
        assert_eq!(error_type, "insufficient_credits");
        assert_eq!(message, "Account balance too low");
    }

    #[test]
    fn test_extract_upstream_error_server_error() {
        let body = r#"{"error":{"message":"Internal error"}}"#;
        let (error_type, message) = extract_upstream_error(500, body);
        assert_eq!(error_type, "server_error");
        assert_eq!(message, "Internal error");
    }

    #[test]
    fn test_extract_upstream_error_empty_body() {
        let (error_type, message) = extract_upstream_error(404, "");
        assert_eq!(error_type, "not_found_error");
        assert_eq!(message, "");
    }

    #[test]
    fn test_parse_failure_error_legacy_format() {
        let err = r#"{"type":"NonRetriableHttpStatus","details":{"status":403,"body":"{\"error\":{\"message\":\"Forbidden\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"forbidden\"}}"}}"#;
        let (error_type, message, status_code) = parse_failure_error(err);
        assert_eq!(error_type, "permission_error");
        assert_eq!(message, "Forbidden");
        assert_eq!(status_code, Some(403));
    }

    #[test]
    fn test_parse_failure_error_details_with_plain_body() {
        let err = r#"{"type":"NonRetriableHttpStatus","details":{"status":429,"body":"slow down"}}"#;
        let (error_type, message, status_code) = parse_failure_error(err);
        assert_eq!(error_type, "rate_limit_error");
        assert_eq!(message, "slow down");
        assert_eq!(status_code, Some(429));
    }

    #[test]
    fn test_parse_failure_error_details_out_of_range_status() {
        let err = r#"{"details":{"status":70000,"body":"boom"}}"#;
        let (error_type, message, status_code) = parse_failure_error(err);
        assert_eq!(error_type, "server_error");
        assert_eq!(message, "boom");
        assert_eq!(status_code, None);
    }

    #[test]
    fn test_parse_failure_error_plain_string() {
        let (error_type, message, status_code) = parse_failure_error("some unknown error");
        assert_eq!(error_type, "server_error");
        assert_eq!(message, "some unknown error");
        assert_eq!(status_code, None);
    }

    #[test]
    fn test_parse_failure_error_bare_openai_body() {
        let (error_type, message, status_code) = parse_failure_error(r#"{"error":{"message":"Something broke"}}"#);
        assert_eq!(error_type, "server_error");
        assert_eq!(message, "Something broke");
        assert_eq!(status_code, None);
    }

    #[test]
    fn test_parse_failure_error_legacy_upstream_returned_format() {
        let err =
            r#"Upstream returned 403: {"error":{"message":"Forbidden","type":"invalid_request_error","param":null,"code":"forbidden"}}"#;
        let (error_type, message, status_code) = parse_failure_error(err);
        assert_eq!(error_type, "permission_error");
        assert_eq!(message, "Forbidden");
        assert_eq!(status_code, Some(403));
    }

    #[test]
    fn test_parse_failure_error_upstream_returned_plain_body() {
        let (error_type, message, status_code) = parse_failure_error("Upstream returned 502: Bad Gateway");
        assert_eq!(error_type, "server_error");
        assert_eq!(message, "Bad Gateway");
        assert_eq!(status_code, Some(502));
    }

    #[test]
    fn test_parse_failure_error_upstream_returned_non_numeric_status() {
        let err = "Upstream returned abc: nope";
        let (error_type, message, status_code) = parse_failure_error(err);
        assert_eq!(error_type, "server_error");
        assert_eq!(message, err);
        assert_eq!(status_code, None);
    }

    #[test]
    fn test_parse_openai_error_rejects_non_matching_bodies() {
        assert_eq!(parse_openai_error("not json"), None);
        assert_eq!(parse_openai_error(r#"{"message":"top-level"}"#), None);
        assert_eq!(parse_openai_error(r#"{"error":{"type":"x"}}"#), None);
        assert_eq!(parse_openai_error(r#"{"error":{"message":42}}"#), None);
        assert_eq!(parse_openai_error(r#"{"error":{"message":"ok"}}"#), Some("ok".to_string()));
    }

    #[test]
    fn test_status_to_error_type_mapping() {
        assert_eq!(status_to_error_type(400), "invalid_request_error");
        assert_eq!(status_to_error_type(401), "authentication_error");
        assert_eq!(status_to_error_type(402), "insufficient_credits");
        assert_eq!(status_to_error_type(403), "permission_error");
        assert_eq!(status_to_error_type(404), "not_found_error");
        assert_eq!(status_to_error_type(429), "rate_limit_error");
        assert_eq!(status_to_error_type(500), "server_error");
        assert_eq!(status_to_error_type(503), "server_error");
    }
}