use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use fusillade::{
BatchInput, CreateSingleRequestBatchInput, CreateStepInput, PostgresRequestManager, PostgresResponseStepManager, RequestId,
RequestTemplateInput, ReqwestHttpClient, ResponseStep, ResponseStepStore, StepId, StepKind as FusilladeStepKind,
StepState as FusilladeStepState, Storage,
};
use onwards::{
ChainStep, MultiStepStore, RecordedStep, ResponseStore, StepDescriptor, StepKind as OnwardsStepKind, StepState as OnwardsStepState,
StoreError,
};
use sqlx_pool_router::PoolProvider;
use uuid::Uuid;
/// Request-scoped inputs stashed before multi-step orchestration begins.
///
/// Stored in `FusilladeResponseStore`'s in-memory map, keyed by the head-step
/// UUID returned from `register_pending`, and read back when deciding the next
/// action and when creating sub-request rows for model-call steps.
#[derive(Debug, Clone)]
pub struct PendingResponseInput {
    /// Raw parent request body; parsed with `transition::parse_parent_request`.
    pub body: String,
    /// API key copied onto every sub-request row created for a model-call step.
    pub api_key: Option<String>,
    /// Attribution copied onto every sub-request row (presumably the user id
    /// produced by `lookup_created_by` — confirm with the caller).
    pub created_by: Option<String>,
    /// Base URL copied onto every sub-request row.
    pub base_url: String,
    /// Tool names passed to `transition::decide_next_action`.
    pub resolved_tool_names: HashSet<String>,
}
/// Name of the HTTP header carrying an onwards response id.
/// NOTE(review): not referenced in this chunk; confirm which side sets/reads it.
pub const ONWARDS_RESPONSE_ID_HEADER: &str = "x-onwards-response-id";
/// Newtype around a UUID, presumably identifying an onwards daemon instance.
/// NOTE(review): not used in this chunk; meaning inferred from the name — confirm.
#[derive(Debug, Clone, Copy)]
pub struct OnwardsDaemonId(pub Uuid);
/// Response store backed by fusillade's Postgres request manager and,
/// optionally, its response-step manager for multi-step orchestration.
pub struct FusilladeResponseStore<P: PoolProvider + Clone> {
    /// Reads, creates, completes and fails request rows.
    request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>,
    /// Set via `with_step_manager`; multi-step methods error while this is `None`.
    step_manager: Option<Arc<PostgresResponseStepManager<P>>>,
    /// In-memory inputs keyed by the head-step UUID in hyphenated string form.
    pending_inputs: Arc<RwLock<HashMap<String, PendingResponseInput>>>,
}
impl<P: PoolProvider + Clone> FusilladeResponseStore<P> {
pub fn new(request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>) -> Self {
Self {
request_manager,
step_manager: None,
pending_inputs: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn with_step_manager(mut self, step_manager: Arc<PostgresResponseStepManager<P>>) -> Self {
self.step_manager = Some(step_manager);
self
}
pub fn register_pending(&self, input: PendingResponseInput) -> Uuid {
let head_step_uuid = Uuid::new_v4();
let key = head_step_uuid.to_string();
if let Ok(mut guard) = self.pending_inputs.write() {
guard.insert(key, input);
}
head_step_uuid
}
pub fn unregister_pending(&self, request_id: &str) {
let key = parse_response_id(request_id)
.map(|u| u.to_string())
.unwrap_or_else(|_| request_id.to_string());
if let Ok(mut guard) = self.pending_inputs.write() {
guard.remove(&key);
}
}
fn pending_input(&self, request_id: &str) -> Result<PendingResponseInput, StoreError> {
let key = parse_response_id(request_id)?.to_string();
self.pending_inputs
.read()
.map_err(|_| StoreError::StorageError("pending_inputs lock poisoned".into()))?
.get(&key)
.cloned()
.ok_or_else(|| {
StoreError::StorageError(format!(
"no pending input registered for response {request_id} — warm path didn't stash it (or it was unregistered)"
))
})
}
pub fn request_manager(&self) -> &PostgresRequestManager<P, ReqwestHttpClient> {
&self.request_manager
}
fn require_step_manager(&self) -> Result<&PostgresResponseStepManager<P>, StoreError> {
self.step_manager.as_deref().ok_or_else(|| {
StoreError::StorageError(
"FusilladeResponseStore was constructed without a step manager — multi-step \
orchestration methods require with_step_manager(...) at construction time"
.into(),
)
})
}
pub async fn get_response(&self, response_id: &str) -> Result<Option<serde_json::Value>, StoreError> {
let parsed_uuid = parse_response_id(response_id)?;
if let Some(step_manager) = self.step_manager.as_deref()
&& let Some(head_step) = step_manager.get_step(StepId(parsed_uuid)).await.map_err(map_fusillade_err)?
{
let Some(sub_request_id) = head_step.request_id else {
return Ok(None);
};
let detail = match self.request_manager.get_request_detail(sub_request_id).await {
Ok(d) => d,
Err(fusillade::FusilladeError::RequestNotFound(_)) => return Ok(None),
Err(e) => return Err(StoreError::StorageError(format!("fetch head sub-request: {e}"))),
};
let mut resp = detail_to_response_object(&detail);
resp["id"] = serde_json::Value::String(format!("resp_{parsed_uuid}"));
return Ok(Some(resp));
}
match self.request_manager.get_request_detail(RequestId(parsed_uuid)).await {
Ok(detail) => Ok(Some(detail_to_response_object(&detail))),
Err(fusillade::FusilladeError::RequestNotFound(_)) => Ok(None),
Err(e) => Err(StoreError::StorageError(format!("Failed to fetch request: {e}"))),
}
}
}
/// Parses a raw step id string into a [`StepId`].
///
/// Anything that is not a valid UUID is reported as [`StoreError::NotFound`]
/// carrying the original input unchanged.
fn parse_step_id(raw: &str) -> Result<StepId, StoreError> {
    match Uuid::parse_str(raw) {
        Ok(uuid) => Ok(StepId::from(uuid)),
        Err(_) => Err(StoreError::NotFound(raw.to_owned())),
    }
}
fn map_step_kind(kind: OnwardsStepKind) -> FusilladeStepKind {
match kind {
OnwardsStepKind::ModelCall => FusilladeStepKind::ModelCall,
OnwardsStepKind::ToolCall => FusilladeStepKind::ToolCall,
}
}
fn map_kind_back(kind: FusilladeStepKind) -> OnwardsStepKind {
match kind {
FusilladeStepKind::ModelCall => OnwardsStepKind::ModelCall,
FusilladeStepKind::ToolCall => OnwardsStepKind::ToolCall,
}
}
fn map_state_back(state: FusilladeStepState) -> OnwardsStepState {
match state {
FusilladeStepState::Pending => OnwardsStepState::Pending,
FusilladeStepState::Processing => OnwardsStepState::Processing,
FusilladeStepState::Completed => OnwardsStepState::Completed,
FusilladeStepState::Failed => OnwardsStepState::Failed,
FusilladeStepState::Canceled => OnwardsStepState::Canceled,
}
}
fn step_to_chain(step: ResponseStep) -> ChainStep {
ChainStep {
id: step.id.0.to_string(),
kind: map_kind_back(step.step_kind),
state: map_state_back(step.state),
sequence: step.step_sequence,
prev_step_id: step.prev_step_id.map(|s| s.0.to_string()),
parent_step_id: step.parent_step_id.map(|s| s.0.to_string()),
response_payload: step.response_payload,
error: step.error,
}
}
/// Wraps a fusillade error as [`StoreError::StorageError`], prefixing the
/// message with `fusillade: ` so its origin is visible.
fn map_fusillade_err(e: fusillade::FusilladeError) -> StoreError {
    let message = format!("fusillade: {}", e);
    StoreError::StorageError(message)
}
/// Marks the request behind `response_id` as failed with status 500 and the
/// given `error` text.
///
/// `response_id` may be `resp_<uuid>` or a bare UUID.
///
/// # Errors
/// `NotFound` for an unparseable id; `StorageError` if the update fails.
pub async fn fail_response<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    response_id: &str,
    error: &str,
) -> Result<(), StoreError> {
    const INTERNAL_ERROR_STATUS: u16 = 500;
    let request_id = RequestId(parse_response_id(response_id)?);
    if let Err(e) = request_manager.fail_request(request_id, error, INTERNAL_ERROR_STATUS).await {
        return Err(StoreError::StorageError(format!("Failed to fail request: {e}")));
    }
    Ok(())
}
/// Reports whether a request row with id `request_id` exists.
///
/// "Not found" maps to `Ok(false)`; any other lookup failure is surfaced as a
/// `StorageError`.
pub async fn request_exists<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    request_id: Uuid,
) -> Result<bool, StoreError> {
    let lookup = request_manager.get_request_detail(RequestId(request_id)).await;
    match lookup {
        Err(fusillade::FusilladeError::RequestNotFound(_)) => Ok(false),
        Err(other) => Err(StoreError::StorageError(format!(
            "Failed to check request existence: {other}"
        ))),
        Ok(_) => Ok(true),
    }
}
/// Inputs used by `complete_response_idempotent` to synthesize a request row
/// when none exists yet. Fields are copied into a `CreateSingleRequestBatchInput`.
pub struct CreateContext<'a> {
    /// Batch id for the synthesized single-request batch.
    pub batch_id: Uuid,
    /// Id given to the synthesized request row.
    pub request_id: Uuid,
    /// Request body stored on the synthesized row.
    pub request_body: &'a str,
    /// Model name stored on the synthesized row.
    pub model: &'a str,
    /// Endpoint stored on the row; synthesis is refused when empty.
    pub endpoint: &'a str,
    /// Base URL stored on the synthesized row.
    pub base_url: &'a str,
    /// API key stored on the row and used to look up `created_by`.
    pub api_key: Option<&'a str>,
}
/// Completes the request behind `response_id`, synthesizing the row first if
/// it does not exist yet.
///
/// This covers the race where complete-response arrives before create-response
/// has inserted the row. On `RequestNotFound`:
/// 1. a single-request batch is created from `create_ctx` in the `processing` state;
/// 2. the completion UPDATE is retried once.
///
/// A failed synthetic create is only logged (the log text attributes it to
/// create-response winning the race). The retry still runs, and its result
/// decides the outcome.
///
/// # Errors
/// - `StoreError::NotFound` if `response_id` is not a valid (optionally
///   `resp_`-prefixed) UUID.
/// - `StoreError::StorageError` if any of the following happens:
///   - the first completion fails with anything other than `RequestNotFound`;
///   - `create_ctx.endpoint` is empty when a row must be synthesized;
///   - the retried completion fails.
///
/// NOTE(review): the synthesized row uses `create_ctx.request_id`, while both
/// completion attempts target the UUID parsed from `response_id`. Presumably
/// callers pass the same UUID for both — confirm.
pub async fn complete_response_idempotent<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    dwctl_pool: &sqlx::PgPool,
    response_id: &str,
    response_body: &str,
    status_code: u16,
    create_ctx: CreateContext<'_>,
) -> Result<(), StoreError> {
    let id = parse_response_id(response_id)?;
    match request_manager.complete_request(RequestId(id), response_body, status_code).await {
        Ok(()) => Ok(()),
        // The row isn't there yet: synthesize it, then retry the UPDATE once.
        Err(fusillade::FusilladeError::RequestNotFound(_)) => {
            tracing::info!(
                response_id = %response_id,
                model = %create_ctx.model,
                endpoint = %create_ctx.endpoint,
                "complete-response synthesizing row (create-response hasn't run yet)"
            );
            // Refuse to synthesize a row without an endpoint.
            if create_ctx.endpoint.is_empty() {
                return Err(StoreError::StorageError(
                    "Cannot synthesize request row: empty endpoint in CreateContext (x-onwards-endpoint header missing upstream)".into(),
                ));
            }
            // Best-effort attribution: `None` when there is no key or the lookup fails.
            let created_by = lookup_created_by(dwctl_pool, create_ctx.api_key).await;
            let batch_input = fusillade::CreateSingleRequestBatchInput {
                batch_id: Some(create_ctx.batch_id),
                request_id: create_ctx.request_id,
                body: create_ctx.request_body.to_string(),
                model: create_ctx.model.to_string(),
                base_url: create_ctx.base_url.to_string(),
                endpoint: create_ctx.endpoint.to_string(),
                completion_window: "0s".to_string(),
                initial_state: "processing".to_string(),
                api_key: create_ctx.api_key.map(String::from),
                created_by,
            };
            // A failed create is deliberately non-fatal; the retry below decides.
            match request_manager.create_single_request_batch(batch_input).await {
                Ok(_) => {
                    tracing::info!(
                        response_id = %response_id,
                        "Synthetic create from complete-response succeeded — row now exists in 'processing'"
                    );
                }
                Err(e) => {
                    tracing::info!(
                        response_id = %response_id,
                        error = %e,
                        "Synthetic create from complete-response failed (likely create-response won the race) — proceeding to UPDATE"
                    );
                }
            }
            // Second and final attempt; its error (not the create's) is returned.
            match request_manager.complete_request(RequestId(id), response_body, status_code).await {
                Ok(()) => {
                    tracing::info!(response_id = %response_id, "Second-attempt UPDATE succeeded — row now 'completed'");
                    Ok(())
                }
                Err(e) => {
                    tracing::warn!(response_id = %response_id, error = %e, "Second-attempt UPDATE failed");
                    Err(StoreError::StorageError(format!("Failed to complete after create: {e}")))
                }
            }
        }
        Err(e) => Err(StoreError::StorageError(format!("Failed to complete request: {e}"))),
    }
}
/// Polls the request behind `response_id` until it reaches a terminal state
/// (`completed`, `failed` or `canceled`) and returns it rendered as a
/// Responses-API object.
///
/// A missing row is treated as "not created yet", and polling continues. Each
/// sleep is capped at the time remaining before `timeout`, so a `poll_interval`
/// longer than the remaining budget cannot overshoot the deadline. One final
/// poll happens at the deadline before giving up.
///
/// # Errors
/// - `StoreError::NotFound` if `response_id` is not a valid (optionally
///   `resp_`-prefixed) UUID.
/// - `StoreError::StorageError` on a lookup failure other than "not found",
///   or when `timeout` elapses first.
pub async fn poll_until_complete<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    response_id: &str,
    poll_interval: std::time::Duration,
    timeout: std::time::Duration,
) -> Result<serde_json::Value, StoreError> {
    let id = parse_response_id(response_id)?;
    let start = std::time::Instant::now();
    loop {
        match request_manager.get_request_detail(RequestId(id)).await {
            Ok(detail) if matches!(detail.status.as_str(), "completed" | "failed" | "canceled") => {
                return Ok(detail_to_response_object(&detail));
            }
            // Still in flight, or the row hasn't been inserted yet: keep polling.
            Ok(_) | Err(fusillade::FusilladeError::RequestNotFound(_)) => {}
            Err(e) => {
                return Err(StoreError::StorageError(format!("Failed to poll request: {e}")));
            }
        }
        let elapsed = start.elapsed();
        if elapsed >= timeout {
            return Err(StoreError::StorageError(format!(
                "Timeout waiting for request {response_id} to complete after {:?}",
                timeout
            )));
        }
        // `elapsed < timeout` here, so the subtraction cannot underflow.
        tokio::time::sleep(poll_interval.min(timeout - elapsed)).await;
    }
}
pub async fn lookup_created_by(pool: &sqlx::PgPool, api_key: Option<&str>) -> Option<String> {
let key = api_key?;
match sqlx::query("SELECT user_id FROM public.api_keys WHERE secret = $1 AND is_deleted = false LIMIT 1")
.bind(key)
.fetch_optional(pool)
.await
{
Ok(Some(row)) => {
use sqlx::Row;
let user_id: Uuid = row.get("user_id");
Some(user_id.to_string())
}
Ok(None) => {
tracing::warn!(key_prefix = &key[..8.min(key.len())], "API key not found for attribution");
None
}
Err(e) => {
tracing::error!(error = %e, "Failed to look up API key for attribution");
None
}
}
}
/// Creates a one-request batch for async processing.
///
/// Returns `(response_id, request_id)`, where `response_id` is
/// `resp_<request_id>`. The stored body is `request` serialized as JSON.
/// `created_by` is the owner of `api_key` as resolved by `lookup_created_by`,
/// and is omitted when unknown.
///
/// # Errors
/// `StoreError::StorageError` if creating the file or batch fails, if listing
/// the batch's requests fails, or if the batch unexpectedly has no requests.
pub async fn create_batch_of_1<P: PoolProvider + Clone>(
    request_manager: &PostgresRequestManager<P, ReqwestHttpClient>,
    request: &serde_json::Value,
    model: &str,
    base_url: &str,
    path: &str,
    completion_window: &str,
    api_key: Option<&str>,
) -> Result<(String, Uuid), StoreError> {
    // Already `None` when there is no key or it can't be resolved, so it is
    // passed straight through instead of round-tripping via an empty string.
    let created_by = lookup_created_by(request_manager.pool(), api_key).await;
    let template = RequestTemplateInput {
        custom_id: None,
        // The template's `endpoint` is the upstream base URL; the API path goes in `path`.
        endpoint: base_url.to_string(),
        method: "POST".to_string(),
        path: path.to_string(),
        body: request.to_string(),
        model: model.to_string(),
        // NOTE(review): left empty on the template; the key is set on the batch below.
        api_key: String::new(),
    };
    let file_id = request_manager
        .create_file("responses_api_single".into(), None, vec![template])
        .await
        .map_err(|e| StoreError::StorageError(format!("Failed to create file: {e}")))?;
    let batch = request_manager
        .create_batch(BatchInput {
            file_id,
            endpoint: path.to_string(),
            completion_window: completion_window.to_string(),
            metadata: None,
            created_by,
            api_key_id: None,
            api_key: api_key.map(String::from),
            total_requests: Some(1),
        })
        .await
        .map_err(|e| StoreError::StorageError(format!("Failed to create batch: {e}")))?;
    let requests = request_manager
        .get_batch_requests(batch.id)
        .await
        .map_err(|e| StoreError::StorageError(format!("Failed to get batch requests: {e}")))?;
    let request_id = requests
        .first()
        .map(|r| *r.id())
        .ok_or_else(|| StoreError::StorageError("Batch created with no requests".into()))?;
    let response_id = format!("resp_{request_id}");
    tracing::debug!(
        response_id = %response_id,
        batch_id = %batch.id,
        completion_window = %completion_window,
        "Created batch of 1 for async processing"
    );
    Ok((response_id, request_id))
}
/// Classifies an upstream HTTP error response.
///
/// Returns the error type derived from `status`, plus the OpenAI-style
/// `error.message` from `body` when present, otherwise the raw body.
fn extract_upstream_error(status: u16, body: &str) -> (&'static str, String) {
    let error_type = status_to_error_type(status);
    let message = match parse_openai_error(body) {
        Some(openai_message) => openai_message,
        None => body.to_string(),
    };
    (error_type, message)
}
/// Decodes a stored failure string into `(error_type, message, status)`.
///
/// Recognized shapes, tried in order:
/// 1. JSON with `details.body` (a string): the status comes from
///    `details.status` when present; a missing status is treated as 500 for
///    the error type only.
/// 2. The legacy text form `Upstream returned <status>: <body>`.
/// 3. Anything else, classified as `server_error` with no status.
///
/// In every case the message is the OpenAI-style `error.message` if the body
/// has one, otherwise the body/text itself.
fn parse_failure_error(err: &str) -> (&'static str, String, Option<u16>) {
    /// Prefers the OpenAI `error.message` inside `body`, falling back to `body`.
    fn message_from(body: &str) -> String {
        parse_openai_error(body).unwrap_or_else(|| body.to_owned())
    }

    let structured = serde_json::from_str::<serde_json::Value>(err).ok();
    if let Some(details) = structured.as_ref().and_then(|v| v.get("details")) {
        let status = details
            .get("status")
            .and_then(serde_json::Value::as_u64)
            .and_then(|s| u16::try_from(s).ok());
        if let Some(body) = details.get("body").and_then(serde_json::Value::as_str) {
            return (status_to_error_type(status.unwrap_or(500)), message_from(body), status);
        }
    }

    let legacy = err.strip_prefix("Upstream returned ").and_then(|rest| {
        let (code, body) = rest.split_once(": ")?;
        let status = code.parse::<u16>().ok()?;
        Some((status, body))
    });
    if let Some((status, body)) = legacy {
        return (status_to_error_type(status), message_from(body), Some(status));
    }

    ("server_error", message_from(err), None)
}
/// Extracts `error.message` from an OpenAI-style JSON error body.
///
/// Returns `None` if `body` is not JSON or the field is missing or not a string.
fn parse_openai_error(body: &str) -> Option<String> {
    let envelope = serde_json::from_str::<serde_json::Value>(body).ok()?;
    envelope
        .get("error")
        .and_then(|error| error.get("message"))
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
}
/// Maps an HTTP status code to an OpenAI-style error type.
///
/// Unlisted codes, including non-error ones, fall back to `server_error`.
fn status_to_error_type(status: u16) -> &'static str {
    const KNOWN: [(u16, &str); 6] = [
        (400, "invalid_request_error"),
        (401, "authentication_error"),
        (402, "insufficient_credits"),
        (403, "permission_error"),
        (404, "not_found_error"),
        (429, "rate_limit_error"),
    ];
    KNOWN
        .iter()
        .find(|&&(code, _)| code == status)
        .map_or("server_error", |&(_, kind)| kind)
}
/// Parses a response id of the form `resp_<uuid>` or a bare `<uuid>`.
///
/// Invalid ids become [`StoreError::NotFound`] describing the parse error.
fn parse_response_id(response_id: &str) -> Result<Uuid, StoreError> {
    let raw = match response_id.strip_prefix("resp_") {
        Some(rest) => rest,
        None => response_id,
    };
    Uuid::parse_str(raw).map_err(|err| StoreError::NotFound(format!("Invalid response ID: {}", err)))
}
/// Maps a fusillade request state to a Responses-API status string.
///
/// Note the spelling change `canceled` → `cancelled`. Unknown states are
/// reported as `failed`.
fn state_to_status(state: &str) -> &'static str {
    if state == "pending" {
        "queued"
    } else if matches!(state, "claimed" | "processing") {
        "in_progress"
    } else if state == "completed" {
        "completed"
    } else if state == "canceled" {
        "cancelled"
    } else {
        "failed"
    }
}
/// Renders a fusillade request row as a Responses-API-style JSON object.
///
/// - `status` comes from `state_to_status`.
/// - `background` is read from the stored request body. It defaults to
///   `false` if the body is missing, not JSON, or lacks a boolean `background`.
/// - Completed rows:
///   - An upstream status >= 400 turns the object into `status: "failed"`
///     with an `error` {type, code, message}. A missing status counts as 200;
///     one that doesn't fit `u16` counts as 500.
///   - Otherwise, if the stored body is JSON, its `output` and `usage` are
///     copied. A body with `choices` (chat-completions shape) is instead
///     wrapped whole as the content of a single assistant message, replacing
///     any `output`.
///   - `completed_at` is set in either case.
/// - Failed rows with a stored error get an `error` derived via
///   `parse_failure_error`; `code` is included only when a status was found.
fn detail_to_response_object(detail: &fusillade::RequestDetail) -> serde_json::Value {
    let status = state_to_status(&detail.status);
    let background = detail
        .body
        .as_deref()
        .and_then(|b| serde_json::from_str::<serde_json::Value>(b).ok())
        .and_then(|v| v.get("background")?.as_bool())
        .unwrap_or(false);
    let mut resp = serde_json::json!({
        "id": format!("resp_{}", detail.id),
        "object": "response",
        "created_at": detail.created_at.timestamp(),
        "status": status,
        "model": detail.model,
        "background": background,
        "output": [],
    });
    if status == "completed" {
        let response_status = match detail.response_status {
            Some(s) => u16::try_from(s).unwrap_or(500),
            None => 200,
        };
        let is_error_response = response_status >= 400;
        if is_error_response {
            // The row finished, but upstream returned an error: surface it as failed.
            resp["status"] = serde_json::json!("failed");
            let (error_type, message) = if let Some(ref body) = detail.response_body {
                extract_upstream_error(response_status, body)
            } else {
                (
                    status_to_error_type(response_status),
                    format!("Upstream returned {response_status}"),
                )
            };
            resp["error"] = serde_json::json!({
                "type": error_type,
                "code": response_status,
                "message": message,
            });
        } else if let Some(ref body) = detail.response_body
            && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body)
        {
            if let Some(output) = parsed.get("output") {
                resp["output"] = output.clone();
            }
            if let Some(usage) = parsed.get("usage") {
                resp["usage"] = usage.clone();
            }
            // Chat-completions body: wrap the whole body as one assistant
            // message. This runs last, so it overrides any `output` copied above.
            if parsed.get("choices").is_some() {
                resp["output"] = serde_json::json!([{
                    "type": "message",
                    "role": "assistant",
                    "content": parsed
                }]);
            }
        }
        resp["completed_at"] = serde_json::json!(detail.completed_at.map(|t| t.timestamp()));
    }
    // `status` is the original mapped state, so an upstream-error "completed"
    // row (already rewritten to failed above) does not enter this branch.
    if status == "failed"
        && let Some(ref err) = detail.error
    {
        let (error_type, message, status_code) = parse_failure_error(err);
        let mut error_obj = serde_json::json!({
            "type": error_type,
            "message": message,
        });
        if let Some(code) = status_code {
            error_obj["code"] = serde_json::json!(code);
        }
        resp["error"] = error_obj;
    }
    resp
}
#[async_trait]
impl<P: PoolProvider + Clone + Send + Sync + 'static> ResponseStore for FusilladeResponseStore<P> {
    /// Persists nothing; returns the response's string `id` field, or an empty
    /// string when it is missing or not a string.
    async fn store(&self, response: &serde_json::Value) -> Result<String, StoreError> {
        let echoed_id = match response.get("id") {
            Some(serde_json::Value::String(id)) => id.clone(),
            _ => String::new(),
        };
        Ok(echoed_id)
    }

    /// Loads a previous response for use as context; see `get_response`.
    async fn get_context(&self, response_id: &str) -> Result<Option<serde_json::Value>, StoreError> {
        let response = self.get_response(response_id).await?;
        Ok(response)
    }
}
#[async_trait]
impl<P: PoolProvider + Clone + Send + Sync + 'static> MultiStepStore for FusilladeResponseStore<P> {
    /// Decides the next orchestration action for `request_id`.
    ///
    /// Parses the registered parent request body, loads the current chain, and
    /// delegates to `transition::decide_next_action` with the resolved tool
    /// names. Errors if no pending input is registered for `request_id`.
    async fn next_action_for(&self, request_id: &str, scope_parent: Option<&str>) -> Result<onwards::NextAction, StoreError> {
        let pending = self.pending_input(request_id)?;
        let parsed = super::transition::parse_parent_request(&pending.body).map_err(StoreError::StorageError)?;
        let chain = <Self as MultiStepStore>::list_chain(self, request_id, scope_parent).await?;
        Ok(super::transition::decide_next_action(&parsed, &chain, &pending.resolved_tool_names))
    }

    /// Appends a step to the chain rooted at `request_id`'s head step.
    ///
    /// - The first step recorded with no `scope_parent` becomes the head. It is
    ///   created with the head UUID itself as its id and has no parent.
    /// - Every other step gets the head as its `parent_step_id`.
    /// - The sequence number is one past the highest sequence currently in the
    ///   chain (1 for an empty chain).
    /// - Model-call steps get a freshly created sub-request row; tool-call
    ///   steps have none.
    ///
    /// NOTE(review): the sub-request row is created before the step row. If
    /// `create_step` fails, that row is left in `processing` with no step
    /// pointing at it. The sequence is also computed from a prior read, so
    /// concurrent callers on one chain could pick the same value. Confirm
    /// whether callers serialize per request.
    async fn record_step(
        &self,
        request_id: &str,
        scope_parent: Option<&str>,
        prev_step: Option<&str>,
        descriptor: &StepDescriptor,
    ) -> Result<RecordedStep, StoreError> {
        let step_manager = self.require_step_manager()?;
        let head_step_uuid = parse_response_id(request_id)?;
        let prev_step_id = prev_step.map(parse_step_id).transpose()?;
        let chain = step_manager.list_chain(StepId(head_step_uuid)).await.map_err(map_fusillade_err)?;
        // Only the very first top-level step takes the head UUID as its own id.
        let is_head = chain.is_empty() && scope_parent.is_none();
        let new_step_id = if is_head { Some(head_step_uuid) } else { None };
        let parent_step_id = if is_head { None } else { Some(StepId(head_step_uuid)) };
        let sequence = chain.iter().map(|s| s.step_sequence).max().unwrap_or(0) + 1;
        let req_id = match descriptor.kind {
            OnwardsStepKind::ModelCall => Some(RequestId(self.create_sub_request_row(request_id, descriptor).await?)),
            OnwardsStepKind::ToolCall => None,
        };
        let id = step_manager
            .create_step(CreateStepInput {
                id: new_step_id,
                request_id: req_id,
                prev_step_id,
                parent_step_id,
                step_kind: map_step_kind(descriptor.kind),
                step_sequence: sequence,
                request_payload: descriptor.request_payload.clone(),
            })
            .await
            .map_err(map_fusillade_err)?;
        Ok(RecordedStep {
            id: id.0.to_string(),
            sequence,
        })
    }

    /// Marks the step with UUID `step_id` as processing.
    async fn mark_step_processing(&self, step_id: &str) -> Result<(), StoreError> {
        let step_manager = self.require_step_manager()?;
        step_manager
            .mark_step_processing(parse_step_id(step_id)?)
            .await
            .map_err(map_fusillade_err)
    }

    /// Marks the step with UUID `step_id` as completed with `payload`.
    async fn complete_step(&self, step_id: &str, payload: &serde_json::Value) -> Result<(), StoreError> {
        let step_manager = self.require_step_manager()?;
        step_manager
            .complete_step(parse_step_id(step_id)?, payload.clone())
            .await
            .map_err(map_fusillade_err)
    }

    /// Marks the step with UUID `step_id` as failed with `error`.
    async fn fail_step(&self, step_id: &str, error: &serde_json::Value) -> Result<(), StoreError> {
        let step_manager = self.require_step_manager()?;
        step_manager
            .fail_step(parse_step_id(step_id)?, error.clone())
            .await
            .map_err(map_fusillade_err)
    }

    /// Lists the chain under `request_id`'s head step.
    ///
    /// NOTE(review): `_scope_parent` is ignored; the whole chain is always returned.
    async fn list_chain(&self, request_id: &str, _scope_parent: Option<&str>) -> Result<Vec<ChainStep>, StoreError> {
        let step_manager = self.require_step_manager()?;
        let head_step_uuid = parse_response_id(request_id)?;
        let steps = step_manager.list_chain(StepId(head_step_uuid)).await.map_err(map_fusillade_err)?;
        Ok(steps.into_iter().map(step_to_chain).collect())
    }

    /// Builds the final response from the full chain via `assembly::assemble_from_chain`.
    async fn assemble_response(&self, request_id: &str) -> Result<serde_json::Value, StoreError> {
        let chain = <Self as MultiStepStore>::list_chain(self, request_id, None).await?;
        Ok(super::assembly::assemble_from_chain(request_id, &chain))
    }
}
impl<P: PoolProvider + Clone + Send + Sync + 'static> FusilladeResponseStore<P> {
    /// Writes the final outcome onto the head step's linked sub-request row.
    ///
    /// Any 2xx `status_code` completes the row with the serialized `body`;
    /// anything else fails the row, using the serialized `body` as the error
    /// text. If the head step does not exist, or has no linked request row,
    /// this is a no-op returning `Ok(())`.
    ///
    /// # Errors
    /// - `NotFound` for an unparseable `request_id`.
    /// - `StorageError` if no step manager is configured, or if the step lookup,
    ///   body serialization or row update fails.
    pub async fn finalize_head_request(&self, request_id: &str, status_code: u16, body: serde_json::Value) -> Result<(), StoreError> {
        let step_manager = self.require_step_manager()?;
        let head_step_uuid = parse_response_id(request_id)?;
        let head_step = step_manager.get_step(StepId(head_step_uuid)).await.map_err(map_fusillade_err)?;
        let Some(sub_request_id) = head_step.and_then(|step| step.request_id) else {
            return Ok(());
        };
        let body_str = serde_json::to_string(&body)
            .map_err(|e| StoreError::StorageError(format!("serialize finalized body: {e}")))?;
        // Treat the whole 2xx range as success. Requiring exactly 200 would
        // record e.g. 201 as a failure, inconsistent with
        // `detail_to_response_object`, which only flags statuses >= 400.
        if (200..300).contains(&status_code) {
            self.request_manager
                .complete_request(sub_request_id, &body_str, status_code)
                .await
                .map_err(|e| StoreError::StorageError(format!("complete head sub-request row: {e}")))
        } else {
            self.request_manager
                .fail_request(sub_request_id, &body_str, status_code)
                .await
                .map_err(|e| StoreError::StorageError(format!("fail head sub-request row: {e}")))
        }
    }

    /// Creates the request row backing a model-call step and returns its id.
    ///
    /// The row is created in the `processing` state as a standalone
    /// single-request batch (no batch id). Its body is the step's
    /// `request_payload`, and its model is the payload's `model` field (empty
    /// if absent). Base URL, API key and `created_by` come from the registered
    /// pending input.
    ///
    /// NOTE(review): the endpoint is hard-coded to `/v1/chat/completions`.
    async fn create_sub_request_row(&self, request_id: &str, descriptor: &StepDescriptor) -> Result<Uuid, StoreError> {
        let pending = self.pending_input(request_id)?;
        let model = descriptor
            .request_payload
            .get("model")
            .and_then(|m| m.as_str())
            .unwrap_or("")
            .to_string();
        let body = serde_json::to_string(&descriptor.request_payload)
            .map_err(|e| StoreError::StorageError(format!("serialize step request_payload: {e}")))?;
        let sub_request_id = Uuid::new_v4();
        let input = CreateSingleRequestBatchInput {
            batch_id: None,
            request_id: sub_request_id,
            body,
            model,
            base_url: pending.base_url,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "0s".to_string(),
            initial_state: "processing".to_string(),
            api_key: pending.api_key,
            created_by: pending.created_by,
        };
        self.request_manager
            .create_single_request_batch(input)
            .await
            .map_err(|e| StoreError::StorageError(format!("create sub-request row: {e}")))?;
        Ok(sub_request_id)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure parsing/mapping helpers in this module.
    use super::*;

    /// Asserts that `extract_upstream_error(status, body)` yields the expected pair.
    fn check_upstream(status: u16, body: &str, want_type: &str, want_message: &str) {
        let (error_type, message) = extract_upstream_error(status, body);
        assert_eq!(error_type, want_type);
        assert_eq!(message, want_message);
    }

    /// Asserts that `parse_failure_error(err)` yields the expected triple.
    fn check_failure(err: &str, want_type: &str, want_message: &str, want_code: Option<u16>) {
        let (error_type, message, status_code) = parse_failure_error(err);
        assert_eq!(error_type, want_type);
        assert_eq!(message, want_message);
        assert_eq!(status_code, want_code);
    }

    #[test]
    fn test_parse_response_id_with_prefix() {
        let expected = Uuid::new_v4();
        let prefixed = format!("resp_{expected}");
        assert_eq!(parse_response_id(&prefixed).unwrap(), expected);
    }

    #[test]
    fn test_parse_response_id_without_prefix() {
        let expected = Uuid::new_v4();
        let bare = expected.to_string();
        assert_eq!(parse_response_id(&bare).unwrap(), expected);
    }

    #[test]
    fn test_parse_response_id_invalid() {
        let outcome = parse_response_id("not-a-uuid");
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(StoreError::NotFound(_))));
    }

    #[test]
    fn test_state_to_status_mapping() {
        let cases = [
            ("pending", "queued"),
            ("claimed", "in_progress"),
            ("processing", "in_progress"),
            ("completed", "completed"),
            ("failed", "failed"),
            ("canceled", "cancelled"),
            ("unknown", "failed"),
        ];
        for (state, expected) in cases {
            assert_eq!(state_to_status(state), expected, "state {state:?}");
        }
    }

    #[test]
    fn test_store_extracts_id_from_response() {
        let response = serde_json::json!({
            "id": "resp_12345678-1234-1234-1234-123456789abc",
            "status": "completed",
        });
        let extracted = response.get("id").and_then(serde_json::Value::as_str).unwrap_or_default();
        assert_eq!(extracted, "resp_12345678-1234-1234-1234-123456789abc");
    }

    #[test]
    fn test_store_handles_missing_id() {
        let response = serde_json::json!({"status": "completed"});
        let extracted = response.get("id").and_then(serde_json::Value::as_str).unwrap_or_default();
        assert_eq!(extracted, "");
    }

    #[test]
    fn test_extract_upstream_error_openai_format() {
        check_upstream(
            403,
            r#"{"error":{"message":"Forbidden","type":"invalid_request_error","param":null,"code":"forbidden"}}"#,
            "permission_error",
            "Forbidden",
        );
    }

    #[test]
    fn test_extract_upstream_error_rate_limit() {
        check_upstream(
            429,
            r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit"}}"#,
            "rate_limit_error",
            "Rate limit exceeded",
        );
    }

    #[test]
    fn test_extract_upstream_error_plain_text() {
        check_upstream(402, "Account balance too low", "insufficient_credits", "Account balance too low");
    }

    #[test]
    fn test_extract_upstream_error_server_error() {
        check_upstream(500, r#"{"error":{"message":"Internal error"}}"#, "server_error", "Internal error");
    }

    #[test]
    fn test_parse_failure_error_legacy_format() {
        check_failure(
            r#"{"type":"NonRetriableHttpStatus","details":{"status":403,"body":"{\"error\":{\"message\":\"Forbidden\",\"type\":\"invalid_request_error\",\"param\":null,\"code\":\"forbidden\"}}"}}"#,
            "permission_error",
            "Forbidden",
            Some(403),
        );
    }

    #[test]
    fn test_parse_failure_error_plain_string() {
        check_failure("some unknown error", "server_error", "some unknown error", None);
    }

    #[test]
    fn test_parse_failure_error_legacy_upstream_returned_format() {
        check_failure(
            r#"Upstream returned 403: {"error":{"message":"Forbidden","type":"invalid_request_error","param":null,"code":"forbidden"}}"#,
            "permission_error",
            "Forbidden",
            Some(403),
        );
    }

    #[test]
    fn test_status_to_error_type_mapping() {
        let cases = [
            (400, "invalid_request_error"),
            (401, "authentication_error"),
            (402, "insufficient_credits"),
            (403, "permission_error"),
            (404, "not_found_error"),
            (429, "rate_limit_error"),
            (500, "server_error"),
            (503, "server_error"),
        ];
        for (status, expected) in cases {
            assert_eq!(status_to_error_type(status), expected, "status {status}");
        }
    }
}