use crate::request::AnyRequest;
use futures::StreamExt;
pub use sqlx_pool_router::{PoolProvider, TestDbPools};
use std::pin::Pin;
use std::sync::Arc;
use anyhow::anyhow;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::stream::Stream;
use sqlx::QueryBuilder;
use sqlx::Row;
use sqlx::postgres::{PgListener, PgPool};
use std::collections::HashMap;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
use super::{DaemonStorage, Storage};
use crate::batch::{
Batch, BatchErrorDetails, BatchErrorItem, BatchId, BatchInput, BatchNotification,
BatchOutputItem, BatchResponseDetails, BatchStatus, File, FileContentItem, FileId,
FileMetadata, FileStreamItem, FileStreamResult, ListBatchesFilter, OutputFileType,
RequestTemplateInput, TemplateId,
};
use crate::daemon::{
AnyDaemonRecord, Daemon, DaemonConfig, DaemonData, DaemonRecord, DaemonState, DaemonStatus,
Dead, Initializing, Running,
};
use crate::error::{FusilladeError, Result};
use crate::http::HttpClient;
use crate::request::{
Canceled, Claimed, Completed, DaemonId, Failed, FailureReason, Pending, Processing, Request,
RequestData, RequestId, RequestState,
};
use super::DaemonExecutor;
use super::utils::{
calculate_error_message_size, calculate_response_body_size, estimate_error_file_size,
estimate_output_file_size,
};
/// Strategy controlling how request templates are bulk-inserted.
///
/// Serialized in `snake_case` (e.g. `{"batched": {"batch_size": 5000}}`),
/// so variant and field names are part of the wire/config format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BatchInsertStrategy {
    /// Insert rows in chunks of `batch_size` per statement.
    /// Must be > 0 (enforced in `with_batch_insert_strategy`).
    Batched { batch_size: usize },
}
impl Default for BatchInsertStrategy {
fn default() -> Self {
BatchInsertStrategy::Batched { batch_size: 5000 }
}
}
/// Postgres-backed request/batch storage and daemon coordination,
/// generic over the database pool provider and the outbound HTTP client.
pub struct PostgresRequestManager<P: PoolProvider, H: HttpClient> {
    // Provides separate read and write pools (`pools.read()` / `pools.write()`).
    pools: P,
    // Client used to execute the batched HTTP requests.
    http_client: Arc<H>,
    // Daemon tuning knobs: timeouts, claim sizes, metadata propagation, etc.
    config: DaemonConfig,
    // Buffer size for file streaming; presumably a channel capacity for
    // download streams — set via `with_download_buffer_size` (default 100).
    download_buffer_size: usize,
    // Chunking strategy for bulk template inserts (default: 5000 per batch).
    batch_insert_strategy: BatchInsertStrategy,
}
/// Builds a [`Batch`] from a dynamically-typed `sqlx` row, reading each
/// column by name. The caller's SELECT projection must include every column
/// referenced here with a compatible type; `Row::get` panics at runtime on a
/// missing or mistyped column, so keep this in sync with the queries that
/// use it.
macro_rules! batch_from_dynamic_row {
    ($row:expr) => {
        Batch {
            id: BatchId($row.get("id")),
            // Nullable UUID columns are wrapped into typed ids when present.
            file_id: $row.get::<Option<Uuid>, _>("file_id").map(FileId),
            endpoint: $row.get("endpoint"),
            completion_window: $row.get("completion_window"),
            metadata: $row.get("metadata"),
            output_file_id: $row.get::<Option<Uuid>, _>("output_file_id").map(FileId),
            error_file_id: $row.get::<Option<Uuid>, _>("error_file_id").map(FileId),
            created_by: $row.get("created_by"),
            created_at: $row.get("created_at"),
            expires_at: $row.get("expires_at"),
            cancelling_at: $row.get("cancelling_at"),
            errors: $row.get("errors"),
            total_requests: $row.get("total_requests"),
            requests_started_at: $row.get("requests_started_at"),
            finalizing_at: $row.get("finalizing_at"),
            completed_at: $row.get("completed_at"),
            failed_at: $row.get("failed_at"),
            cancelled_at: $row.get("cancelled_at"),
            deleted_at: $row.get("deleted_at"),
            pending_requests: $row.get("pending_requests"),
            in_progress_requests: $row.get("in_progress_requests"),
            completed_requests: $row.get("completed_requests"),
            failed_requests: $row.get("failed_requests"),
            canceled_requests: $row.get("canceled_requests"),
            notification_sent_at: $row.get("notification_sent_at"),
            api_key_id: $row.get::<Option<Uuid>, _>("api_key_id"),
        }
    };
}
/// Builds a [`BatchStatus`] (the per-batch request-count summary) from a
/// dynamically-typed `sqlx` row. Same caveat as `batch_from_dynamic_row!`:
/// `Row::get` panics at runtime if the caller's projection is missing a
/// column or uses an incompatible type.
macro_rules! batch_status_from_dynamic_row {
    ($row:expr) => {
        BatchStatus {
            batch_id: BatchId($row.get("batch_id")),
            file_id: $row.get::<Option<Uuid>, _>("file_id").map(FileId),
            file_name: $row.get("file_name"),
            total_requests: $row.get("total_requests"),
            pending_requests: $row.get("pending_requests"),
            in_progress_requests: $row.get("in_progress_requests"),
            completed_requests: $row.get("completed_requests"),
            failed_requests: $row.get("failed_requests"),
            canceled_requests: $row.get("canceled_requests"),
            started_at: $row.get("started_at"),
            failed_at: $row.get("failed_at"),
            created_at: $row.get("created_at"),
        }
    };
}
impl<P: PoolProvider> PostgresRequestManager<P, crate::http::ReqwestHttpClient> {
    /// Creates a manager wired to the default reqwest-based HTTP client.
    ///
    /// The client's chunk/body timeouts and streamable-endpoint list come
    /// from `config`; the download buffer size and batch-insert strategy
    /// start at their defaults.
    pub fn new(pools: P, config: DaemonConfig) -> Self {
        // Alias for converting the config's millisecond fields.
        let ms = std::time::Duration::from_millis;
        let http_client = Arc::new(crate::http::ReqwestHttpClient::new(
            ms(config.first_chunk_timeout_ms),
            ms(config.chunk_timeout_ms),
            ms(config.body_timeout_ms),
            config.streamable_endpoints.clone(),
        ));
        Self {
            pools,
            http_client,
            config,
            download_buffer_size: 100,
            batch_insert_strategy: BatchInsertStrategy::default(),
        }
    }
}
impl<P: PoolProvider, H: HttpClient + 'static> PostgresRequestManager<P, H> {
/// Creates a manager with a caller-supplied HTTP client (useful for tests
/// or alternative transports) and default configuration; customize with
/// the `with_*` builder methods.
pub fn with_client(pools: P, http_client: Arc<H>) -> Self {
    Self {
        pools,
        http_client,
        config: DaemonConfig::default(),
        download_buffer_size: 100,
        batch_insert_strategy: BatchInsertStrategy::default(),
    }
}
pub fn with_config(mut self, config: DaemonConfig) -> Self {
self.config = config;
self
}
pub fn with_download_buffer_size(mut self, buffer_size: usize) -> Self {
self.download_buffer_size = buffer_size;
self
}
/// Sets the bulk-insert strategy (builder-style).
///
/// # Panics
/// Panics if the strategy's `batch_size` is zero.
pub fn with_batch_insert_strategy(mut self, strategy: BatchInsertStrategy) -> Self {
    // `Batched` is currently the only variant, so this pattern is irrefutable.
    let BatchInsertStrategy::Batched { batch_size } = strategy;
    assert!(
        batch_size > 0,
        "batch_size must be greater than 0, got {}",
        batch_size
    );
    self.batch_insert_strategy = strategy;
    self
}
/// Marks a batch as failed, recording `{"message": error_message}` in its
/// `errors` column and stamping `failed_at = NOW()`.
///
/// Idempotent: the `failed_at IS NULL` guard means only the first call
/// takes effect; later calls are no-ops (still `Ok(())`).
pub async fn mark_batch_failed(&self, batch_id: BatchId, error_message: &str) -> Result<()> {
    sqlx::query!(
        r#"
        UPDATE batches
        SET failed_at = NOW(),
            errors = $2
        WHERE id = $1 AND failed_at IS NULL
        "#,
        *batch_id as Uuid,
        serde_json::json!({"message": error_message}),
    )
    .execute(self.pools.write())
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to mark batch as failed: {}", e)))?;
    Ok(())
}
/// Returns a reference to the primary (write) pool.
///
/// NOTE(review): despite the generic name this always hands out the write
/// pool; read-only callers wanting a replica should use `pools.read()`.
pub fn pool(&self) -> &PgPool {
    self.pools.write()
}
/// Creates a Postgres LISTEN/NOTIFY listener connected via the write pool.
pub async fn create_listener(&self) -> Result<PgListener> {
    PgListener::connect_with(self.pools.write())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to create listener: {}", e)))
}
}
impl<P: PoolProvider, H: HttpClient + 'static> PostgresRequestManager<P, H> {
#[tracing::instrument(skip(self))]
/// Resets stale `claimed`/`processing` requests back to `pending` so other
/// daemons can pick them up, returning how many rows were reset.
///
/// A request is considered stale when either:
/// - it has sat in `claimed`/`processing` past the configured claim or
///   processing timeout (time-based fallback), or
/// - its owning daemon is marked `dead` or has a heartbeat older than
///   `stale_daemon_threshold_ms` (daemon-aware reclaim).
///
/// At most `config.unclaim_batch_size` rows are reset per call. Duration
/// and reclaim-count metrics are always recorded; a warning is logged only
/// when something was actually reclaimed.
async fn unclaim_stale_requests(&self) -> Result<usize> {
    let claim_timeout_ms = self.config.claim_timeout_ms as i64;
    let processing_timeout_ms = self.config.processing_timeout_ms as i64;
    let stale_daemon_threshold_ms = self.config.stale_daemon_threshold_ms as i64;
    let limit = self.config.unclaim_batch_size as i64;
    let unclaim_start = std::time::Instant::now();
    // Intervals are built from text parameters ($1..$3) concatenated with
    // ' milliseconds' and cast to INTERVAL; hence the .to_string() binds.
    let result = sqlx::query!(
        r#"
        UPDATE requests
        SET
            state = 'pending',
            daemon_id = NULL,
            claimed_at = NULL,
            started_at = NULL
        WHERE id IN (
            SELECT id FROM (
                -- Time-based fallback: request stuck too long regardless of daemon state
                SELECT r.id FROM requests r
                WHERE
                    (r.state = 'claimed' AND r.claimed_at < NOW() - ($1 || ' milliseconds')::INTERVAL)
                    OR
                    (r.state = 'processing' AND r.started_at < NOW() - ($2 || ' milliseconds')::INTERVAL)
                UNION
                -- Daemon-aware reclaim: daemon is dead or its heartbeat went stale
                SELECT r.id FROM requests r
                WHERE
                    r.state IN ('claimed', 'processing')
                    AND r.daemon_id IN (
                        SELECT d.id FROM daemons d
                        WHERE d.status = 'dead'
                        OR d.last_heartbeat < NOW() - ($3 || ' milliseconds')::INTERVAL
                    )
            ) sub
            LIMIT $4
        )
        "#,
        claim_timeout_ms.to_string(),
        processing_timeout_ms.to_string(),
        stale_daemon_threshold_ms.to_string(),
        limit,
    )
    .execute(self.pools.write())
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to unclaim stale requests: {}", e)))?;
    metrics::histogram!("fusillade_unclaim_stale_duration_seconds")
        .record(unclaim_start.elapsed().as_secs_f64());
    let count = result.rows_affected() as usize;
    if count > 0 {
        metrics::counter!("fusillade_stale_requests_reclaimed_total").increment(count as u64);
        tracing::warn!(
            count = count,
            claim_timeout_ms,
            processing_timeout_ms,
            stale_daemon_threshold_ms,
            "Unclaimed stale requests (likely due to daemon crash or shutdown)"
        );
    }
    Ok(count)
}
/// Lazily expires a file on read: if the file is `Processed` and its
/// `expires_at` is in the past, flips its status to `expired` in the
/// database and mirrors that in the passed-in `file`.
///
/// Returns `Ok(true)` only when this call performed the transition.
/// The `status = 'processed'` guard in the UPDATE keeps the write
/// idempotent under concurrent readers.
async fn check_and_mark_expired(&self, file: &mut File) -> Result<bool> {
    if file.status != crate::batch::FileStatus::Processed {
        return Ok(false);
    }
    if let Some(expires_at) = file.expires_at
        && Utc::now() > expires_at
    {
        sqlx::query!(
            r#"
            UPDATE files
            SET status = 'expired'
            WHERE id = $1 AND status = 'processed'
            "#,
            *file.id as Uuid,
        )
        .execute(self.pools.write())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to mark file as expired: {}", e)))?;
        // Keep the in-memory view consistent with the DB transition.
        file.status = crate::batch::FileStatus::Expired;
        return Ok(true);
    }
    Ok(false)
}
/// Derives the Postgres advisory-lock key for a file.
///
/// Only the low 64 bits of the UUID are used, reinterpreted as `i64`
/// (advisory locks are keyed by bigint). Distinct files can therefore
/// collide on a key in principle; with how the lock is used here that
/// would only serialize their size finalization.
fn file_lock_key(file_id: FileId) -> i64 {
    let low_bits = file_id.0.as_u128() as u64;
    low_bits as i64
}
fn calculate_virtual_file_size_from_batch(
&self,
batch: &Batch,
file_type: OutputFileType,
raw_size_sum: i64,
) -> Option<i64> {
let request_count = match file_type {
OutputFileType::Output => batch.completed_requests,
OutputFileType::Error => batch.failed_requests,
};
if request_count == 0 {
return Some(0);
}
match file_type {
OutputFileType::Output => estimate_output_file_size(raw_size_sum, request_count, None),
OutputFileType::Error => estimate_error_file_size(raw_size_sum, request_count, None),
}
}
/// Estimates the virtual size of a batch output/error file from a
/// dynamically-typed row that carries `calculated_size` and the batch's
/// request counters.
///
/// Returns `Ok(None)` when the size is already finalized or the file is
/// not a batch output/error file (caller should use the stored size).
/// Returns `Ok(Some(0))` when the relevant request count is zero, and
/// `Ok(None)` (with a warning) if the estimation helper overflows.
fn calculate_virtual_file_size_from_row(
    &self,
    row: &sqlx::postgres::PgRow,
    purpose: &Option<crate::batch::Purpose>,
    size_finalized: bool,
) -> Result<Option<i64>> {
    if size_finalized
        || (purpose != &Some(crate::batch::Purpose::BatchOutput)
            && purpose != &Some(crate::batch::Purpose::BatchError))
    {
        return Ok(None);
    }
    // Columns may be absent or NULL depending on the caller's projection;
    // try_get().ok().flatten() treats both as "missing" -> default 0.
    let raw_size_sum: Option<i64> = row.try_get("calculated_size").ok().flatten();
    let raw_sum = raw_size_sum.unwrap_or(0);
    let completed: Option<i64> = row.try_get("completed_requests").ok().flatten();
    let failed: Option<i64> = row.try_get("failed_requests").ok().flatten();
    // Output files scale with completed requests, error files with failed.
    let request_count = if purpose == &Some(crate::batch::Purpose::BatchOutput) {
        completed.unwrap_or(0)
    } else {
        failed.unwrap_or(0)
    };
    if request_count == 0 {
        return Ok(Some(0));
    }
    let estimated_size = if purpose == &Some(crate::batch::Purpose::BatchOutput) {
        estimate_output_file_size(raw_sum, request_count, None)
    } else {
        estimate_error_file_size(raw_sum, request_count, None)
    };
    if estimated_size.is_none() {
        tracing::warn!(
            "File size estimation overflow for {:?} file with {} requests",
            purpose,
            request_count
        );
    }
    Ok(estimated_size)
}
/// A batch is complete once every request has reached a terminal state
/// (completed, failed, or canceled). A batch with zero requests is never
/// considered complete.
fn is_batch_complete(batch: &Batch) -> bool {
    if batch.total_requests == 0 {
        return false;
    }
    batch.completed_requests + batch.failed_requests + batch.canceled_requests
        == batch.total_requests
}
/// Writes the final size for a virtual output/error file, guarded by a
/// Postgres session advisory lock so only one worker finalizes at a time.
///
/// Returns `Ok(false)` without writing when the lock could not be acquired
/// (another worker holds it, or the lock query unexpectedly returned NULL).
/// The UPDATE itself is idempotent via the `size_finalized = FALSE` guard.
/// The lock is always released on a best-effort basis; an unlock failure
/// is only logged, and the UPDATE's result is surfaced afterwards.
async fn finalize_file_size(
    pool: &PgPool,
    file_id: FileId,
    estimated_size: i64,
) -> Result<bool> {
    let lock_key = Self::file_lock_key(file_id);
    let lock_acquired = match sqlx::query_scalar!("SELECT pg_try_advisory_lock($1)", lock_key)
        .fetch_one(pool)
        .await
    {
        Ok(Some(acquired)) => acquired,
        Ok(None) => {
            tracing::warn!(
                file_id = %file_id,
                "Advisory lock query returned NULL unexpectedly"
            );
            false
        }
        Err(e) => {
            tracing::error!(
                file_id = %file_id,
                error = %e,
                "Database error while trying to acquire advisory lock"
            );
            return Err(FusilladeError::Other(anyhow!(
                "Failed to acquire lock: {}",
                e
            )));
        }
    };
    if !lock_acquired {
        return Ok(false);
    }
    // Holding the lock: persist the size exactly once.
    let result = sqlx::query!(
        r#"
        UPDATE files
        SET size_bytes = $2, size_finalized = TRUE
        WHERE id = $1 AND size_finalized = FALSE
        "#,
        *file_id as Uuid,
        estimated_size,
    )
    .execute(pool)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to update file size: {}", e)));
    // Unlock before propagating any UPDATE error, so the lock is not leaked
    // on the error path.
    if let Err(e) = sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", lock_key)
        .fetch_one(pool)
        .await
    {
        tracing::warn!(
            file_id = %file_id,
            error = %e,
            "Failed to release advisory lock (will be released on connection return to pool)"
        );
    }
    result?;
    Ok(true)
}
/// If the row's counters show the batch is complete (all requests in a
/// terminal state and a nonzero total), spawns a background task to
/// finalize the file's size. Missing/NULL counter columns mean we cannot
/// decide, so nothing is spawned.
///
/// Fire-and-forget: the spawned task's outcome is only logged; the caller
/// does not await it.
fn spawn_finalize_if_complete(
    &self,
    row: &sqlx::postgres::PgRow,
    file_id: FileId,
    estimated_size: i64,
) {
    let total: Option<i64> = row.try_get("total_requests").ok().flatten();
    let completed: Option<i64> = row.try_get("completed_requests").ok().flatten();
    let failed: Option<i64> = row.try_get("failed_requests").ok().flatten();
    let canceled: Option<i64> = row.try_get("canceled_requests").ok().flatten();
    let in_progress: Option<i64> = row.try_get("in_progress_requests").ok().flatten();
    // All counters must be present; in_progress is only checked for
    // presence, not used in the completeness arithmetic.
    if let (Some(total_count), Some(comp), Some(fail), Some(canc), Some(_prog)) =
        (total, completed, failed, canceled, in_progress)
    {
        let terminal_count = comp + fail + canc;
        let is_complete = terminal_count == total_count && total_count > 0;
        if is_complete {
            let pool = self.pools.write().clone();
            tokio::spawn(async move {
                if let Err(e) = Self::finalize_file_size(&pool, file_id, estimated_size).await {
                    tracing::warn!("Failed to finalize file size for {}: {}", file_id, e);
                }
            });
        }
    }
}
/// Recomputes a batch output/error file's estimated size and, if the
/// owning batch is complete, persists it as final.
///
/// No-ops for files that are already finalized or are not batch
/// output/error files, and for files whose owning batch cannot be found.
/// `file.size_bytes` is always refreshed with the latest estimate;
/// `file.size_finalized` flips only when this call actually won the
/// finalization (advisory lock + idempotent UPDATE).
async fn maybe_finalize_file_size(&self, file: &mut File) -> Result<()> {
    if file.size_finalized {
        return Ok(());
    }
    let file_type = match file.purpose {
        Some(crate::batch::Purpose::BatchOutput) => OutputFileType::Output,
        Some(crate::batch::Purpose::BatchError) => OutputFileType::Error,
        _ => return Ok(()),
    };
    let batch = match self.get_batch_by_output_file_id(file.id, file_type).await? {
        Some(b) => b,
        None => return Ok(()),
    };
    // Output files aggregate completed requests; error files, failed ones.
    let state_filter = match file_type {
        OutputFileType::Output => "completed",
        OutputFileType::Error => "failed",
    };
    let raw_size_sum = sqlx::query_scalar!(
        r#"
        SELECT COALESCE(SUM(response_size), 0)::BIGINT as "sum!"
        FROM requests
        WHERE batch_id = $1
        AND state = $2
        "#,
        *batch.id as Uuid,
        state_filter,
    )
    .fetch_one(self.pools.write())
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to calculate file size: {}", e)))?;
    // Overflow in the estimator degrades to 0 rather than erroring here.
    let estimated_size = self
        .calculate_virtual_file_size_from_batch(&batch, file_type, raw_size_sum)
        .unwrap_or(0);
    file.size_bytes = estimated_size;
    if !Self::is_batch_complete(&batch) {
        return Ok(());
    }
    let finalized =
        Self::finalize_file_size(self.pools.write(), file.id, estimated_size).await?;
    if finalized {
        file.size_finalized = true;
    }
    Ok(())
}
/// Fetches a non-deleted file by id from the given pool and maps it into
/// the domain [`File`] type.
///
/// Errors if the file does not exist (or is soft-deleted), or if its
/// stored `status`/`purpose` strings fail to parse into their enums.
async fn get_file_from_pool(&self, file_id: FileId, pool: &PgPool) -> Result<File> {
    let row = sqlx::query!(
        r#"
        SELECT id, name, description, size_bytes, size_finalized, status, error_message, purpose, expires_at, deleted_at, uploaded_by, created_at, updated_at, api_key_id
        FROM files
        WHERE id = $1 AND deleted_at IS NULL
        "#,
        *file_id as Uuid,
    )
    .fetch_optional(pool)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch file: {}", e)))?
    .ok_or_else(|| FusilladeError::Other(anyhow!("File not found")))?;
    let status = row
        .status
        .parse::<crate::batch::FileStatus>()
        .map_err(|e| {
            FusilladeError::Other(anyhow!("Invalid file status '{}': {}", row.status, e))
        })?;
    // purpose is nullable; parse only when present.
    let purpose = row
        .purpose
        .map(|s| s.parse::<crate::batch::Purpose>())
        .transpose()
        .map_err(|e| FusilladeError::Other(anyhow!("Invalid purpose: {}", e)))?;
    Ok(File {
        id: FileId(row.id),
        name: row.name,
        description: row.description,
        size_bytes: row.size_bytes,
        size_finalized: row.size_finalized,
        status,
        error_message: row.error_message,
        purpose,
        expires_at: row.expires_at,
        deleted_at: row.deleted_at,
        uploaded_by: row.uploaded_by,
        created_at: row.created_at,
        updated_at: row.updated_at,
        api_key_id: row.api_key_id,
    })
}
}
#[async_trait]
impl<P: PoolProvider, H: HttpClient + 'static> Storage for PostgresRequestManager<P, H> {
/// Counts requests per model per completion window, considering only
/// requests in the given `states` whose batch expires within each window
/// and is not being cancelled.
///
/// `windows` pairs a label with a window length in seconds; the result maps
/// model -> (window label -> count). `model_filter` empty means all models.
/// `strict` routes the query to the write pool (read-your-writes) instead
/// of a replica. Returns an empty map when there is nothing to count.
async fn get_pending_request_counts_by_model_and_completion_window(
    &self,
    windows: &[(String, i64)], states: &[String], model_filter: &[String], strict: bool,
) -> Result<HashMap<String, HashMap<String, i64>>> {
    if windows.is_empty() || states.is_empty() {
        return Ok(HashMap::new());
    }
    let (labels, seconds): (Vec<String>, Vec<i64>) = windows.iter().cloned().unzip();
    let pool = if strict {
        self.pools.write()
    } else {
        self.pools.read()
    };
    // Dynamic (unchecked) query: arrays are UNNESTed into a windows CTE and
    // cross-joined so each (model, window) pair gets its own filtered count.
    let rows = sqlx::query(
        r#"
        WITH windows(label, window_seconds) AS (
            SELECT * FROM UNNEST($1::text[], $2::bigint[])
        )
        SELECT
            r.model as model,
            w.label as completion_window,
            COUNT(*) FILTER (
                WHERE b.expires_at <= NOW() + make_interval(secs => w.window_seconds)
            )::BIGINT as count
        FROM requests r
        JOIN batches b ON r.batch_id = b.id
        CROSS JOIN windows w
        WHERE r.state = ANY($3)
        AND r.template_id IS NOT NULL
        AND b.cancelling_at IS NULL
        AND (cardinality($4::text[]) = 0 OR r.model = ANY($4))
        GROUP BY r.model, w.label
        "#,
    )
    .bind(&labels)
    .bind(&seconds)
    .bind(states)
    .bind(model_filter)
    .fetch_all(pool)
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!(
            "Failed to get pending request counts by model and completion window: {}",
            e
        ))
    })?;
    let mut result: HashMap<String, HashMap<String, i64>> = HashMap::new();
    for row in rows {
        let model: String = row
            .try_get("model")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read model: {}", e)))?;
        let completion_window: String = row.try_get("completion_window").map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to read completion_window: {}", e))
        })?;
        let count: i64 = row
            .try_get("count")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read count: {}", e)))?;
        result
            .entry(model)
            .or_default()
            .insert(completion_window, count);
    }
    Ok(result)
}
#[tracing::instrument(skip(self, available_capacity, user_active_counts), fields(limit))]
/// Atomically claims up to `limit` pending requests for `daemon_id`,
/// respecting per-model capacity and ranking candidate batches by a blend
/// of user fairness (fewer active requests first) and urgency (closer
/// batch expiry first), weighted by `config.urgency_weight`.
///
/// Stale requests are unclaimed first so they re-enter the candidate pool.
/// Returned requests are fully hydrated `Request<Claimed>` values including
/// whichever batch fields `config.batch_metadata_fields` asks to propagate.
async fn claim_requests(
    &self,
    limit: usize,
    daemon_id: DaemonId,
    available_capacity: &std::collections::HashMap<String, usize>,
    user_active_counts: &std::collections::HashMap<String, usize>,
) -> Result<Vec<Request<Claimed>>> {
    // Recover requests stranded by crashed/stale daemons before claiming.
    let unclaimed_count = self.unclaim_stale_requests().await?;
    if unclaimed_count > 0 {
        tracing::info!(
            unclaimed_count,
            "Unclaimed stale requests before claiming new ones"
        );
    }
    let now = Utc::now();
    // Only models with spare capacity participate in this round.
    let mut model_capacity_pairs: Vec<(String, i64)> = available_capacity
        .iter()
        .filter(|(_, cap)| **cap > 0)
        .map(|(model, cap)| (model.clone(), *cap as i64))
        .collect();
    {
        // Shuffle so no model is systematically favored by map iteration order.
        use rand::seq::SliceRandom;
        let mut rng = rand::rng();
        model_capacity_pairs.shuffle(&mut rng);
    }
    let models_arr: Vec<String> = model_capacity_pairs
        .iter()
        .map(|(m, _)| m.clone())
        .collect();
    let capacities_arr: Vec<i64> = model_capacity_pairs.iter().map(|(_, c)| *c).collect();
    tracing::debug!(
        model_count = models_arr.len(),
        "Claiming for models with available capacity"
    );
    if models_arr.is_empty() {
        tracing::debug!("No models with available capacity, skipping claim");
        return Ok(Vec::new());
    }
    // Parallel arrays for the user-fairness UNNEST join.
    let user_ids_arr: Vec<String> = user_active_counts.keys().cloned().collect();
    let user_counts_arr: Vec<i64> = user_ids_arr
        .iter()
        .map(|u| *user_active_counts.get(u).unwrap_or(&0) as i64)
        .collect();
    // Single statement: select candidates per model (FOR UPDATE SKIP LOCKED
    // so concurrent daemons don't contend), rank them, then flip them to
    // 'claimed' and return both request and batch columns for hydration.
    let rows = sqlx::query!(
        r#"
        WITH active_batch_ids AS MATERIALIZED (
            SELECT b.id, b.expires_at, b.created_by
            FROM batches b
            WHERE b.cancelling_at IS NULL
            AND b.deleted_at IS NULL
            AND b.completed_at IS NULL
            AND b.failed_at IS NULL
            AND b.cancelled_at IS NULL
            AND EXISTS (
                SELECT 1 FROM requests r
                WHERE r.batch_id = b.id
                AND r.state = 'pending'
            )
        ),
        user_priority AS (
            SELECT * FROM unnest($6::TEXT[], $7::BIGINT[]) AS u(user_id, active_count)
        ),
        to_claim AS (
            SELECT claimed.id, claimed.template_id, claimed.batch_id
            FROM unnest($4::TEXT[], $5::BIGINT[]) AS m(model, capacity)
            CROSS JOIN LATERAL (
                SELECT r2.id, r2.template_id, r2.batch_id
                FROM active_batch_ids ab
                LEFT JOIN user_priority up ON ab.created_by = up.user_id
                CROSS JOIN LATERAL (
                    SELECT r3.id, r3.template_id, r3.batch_id
                    FROM requests r3
                    WHERE r3.state = 'pending'
                    AND r3.model = m.model
                    AND r3.template_id IS NOT NULL
                    AND r3.batch_id = ab.id
                    AND (r3.not_before IS NULL OR r3.not_before <= $3)
                    LIMIT m.capacity
                    FOR UPDATE OF r3 SKIP LOCKED
                ) r2
                ORDER BY
                    (1.0 - $8::DOUBLE PRECISION)
                    * COALESCE(up.active_count, 0)::DOUBLE PRECISION
                    / GREATEST(NULLIF((SELECT MAX(v) FROM unnest($7::BIGINT[]) v), 0), 1)::DOUBLE PRECISION
                    + $8::DOUBLE PRECISION
                    * LEAST(GREATEST(EXTRACT(EPOCH FROM ab.expires_at - $3), 0.0) / 86400.0, 1.0)
                    ASC,
                    ab.expires_at ASC,
                    ab.id ASC
                LIMIT m.capacity
            ) claimed
            LIMIT $2::BIGINT
        )
        UPDATE requests r
        SET
            state = 'claimed',
            daemon_id = $1,
            claimed_at = $3
        FROM to_claim tc
        JOIN active_request_templates t ON tc.template_id = t.id
        JOIN batches b ON tc.batch_id = b.id
        WHERE r.id = tc.id
        RETURNING r.id, r.batch_id as "batch_id!", r.template_id as "template_id!", r.retry_attempt,
            t.custom_id, t.endpoint as "endpoint!", t.method as "method!", t.path as "path!",
            t.body as "body!", t.model as "model!", COALESCE(b.api_key, t.api_key) as "api_key!",
            b.expires_at as batch_expires_at,
            b.id::TEXT as "batch_id_str!",
            b.file_id::TEXT as "batch_file_id!",
            b.endpoint as "batch_endpoint!",
            b.completion_window as "batch_completion_window!",
            b.metadata::TEXT as "batch_metadata",
            b.output_file_id::TEXT as "batch_output_file_id",
            b.error_file_id::TEXT as "batch_error_file_id",
            b.created_by as "batch_created_by!",
            to_char(b.created_at AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"') as "batch_created_at!",
            to_char(b.expires_at AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"') as "batch_expires_at_str",
            to_char(b.cancelling_at AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"') as "batch_cancelling_at",
            b.errors::TEXT as "batch_errors",
            b.total_requests::TEXT as "batch_total_requests!"
        "#,
        *daemon_id as Uuid,
        limit as i64,
        now,
        &models_arr,
        &capacities_arr,
        &user_ids_arr,
        &user_counts_arr,
        self.config.urgency_weight,
    )
    .fetch_all(self.pools.write())
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!(
            "Failed to claim requests: {}",
            e
        ))
    })?;
    let mut all_claimed = Vec::new();
    let claimed_count = rows.len();
    if claimed_count > 0 {
        tracing::debug!(
            claimed = claimed_count,
            "Claimed requests across all models"
        );
        // Batch metadata JSON is parsed at most once per distinct batch.
        let mut parsed_metadata_cache: std::collections::HashMap<
            Uuid,
            Option<serde_json::Value>,
        > = std::collections::HashMap::new();
        all_claimed.extend(rows.into_iter().map(|row| {
            let mut batch_metadata = std::collections::HashMap::new();
            let parsed_metadata =
                parsed_metadata_cache
                    .entry(row.batch_id)
                    .or_insert_with(|| {
                        row.batch_metadata
                            .as_deref()
                            .and_then(|s| serde_json::from_str(s).ok())
                    });
            // Propagate configured batch fields; well-known column names are
            // looked up directly, anything else falls back to string values
            // inside the batch's metadata JSON.
            for field_name in &self.config.batch_metadata_fields {
                let value: Option<&str> = match field_name.as_str() {
                    "id" => Some(&row.batch_id_str),
                    "file_id" => Some(&row.batch_file_id),
                    "endpoint" => Some(&row.batch_endpoint),
                    "completion_window" => Some(&row.batch_completion_window),
                    "metadata" => row.batch_metadata.as_deref(),
                    "output_file_id" => row.batch_output_file_id.as_deref(),
                    "error_file_id" => row.batch_error_file_id.as_deref(),
                    "created_by" => Some(&row.batch_created_by),
                    "created_at" => Some(&row.batch_created_at),
                    "expires_at" => row.batch_expires_at_str.as_deref(),
                    "cancelling_at" => row.batch_cancelling_at.as_deref(),
                    "errors" => row.batch_errors.as_deref(),
                    "total_requests" => Some(&row.batch_total_requests),
                    _ => None,
                };
                if let Some(v) = value {
                    batch_metadata.insert(field_name.clone(), v.to_string());
                } else if let Some(metadata_json) = parsed_metadata.as_ref() {
                    if let Some(v) = metadata_json.get(field_name).and_then(|v| v.as_str()) {
                        batch_metadata.insert(field_name.clone(), v.to_string());
                    }
                }
            }
            Request {
                state: Claimed {
                    daemon_id,
                    claimed_at: now,
                    retry_attempt: row.retry_attempt as u32,
                    batch_expires_at: row.batch_expires_at,
                },
                data: RequestData {
                    id: RequestId(row.id),
                    batch_id: BatchId(row.batch_id),
                    template_id: TemplateId(row.template_id),
                    custom_id: row.custom_id,
                    endpoint: row.endpoint,
                    method: row.method,
                    path: row.path,
                    body: row.body,
                    model: row.model,
                    api_key: row.api_key,
                    created_by: row.batch_created_by.clone(),
                    batch_metadata,
                },
            }
        }));
    }
    tracing::debug!(
        total_claimed = all_claimed.len(),
        "Finished claiming requests across all models"
    );
    Ok(all_claimed)
}
/// Persists a request's current typed state to the database, retrying
/// transient failures up to 3 times with exponential backoff (100ms,
/// 200ms, 400ms).
///
/// Each state variant maps to its own UPDATE; a zero-rows-affected result
/// means the request row does not exist and is surfaced immediately as
/// `RequestNotFound` (never retried). On success always returns `Ok(None)`
/// — the `Option<RequestId>` in the signature is presumably for other
/// `Storage` implementations.
async fn persist<T: RequestState + Clone>(
    &self,
    request: &Request<T>,
) -> Result<Option<RequestId>>
where
    AnyRequest: From<Request<T>>,
{
    const MAX_ATTEMPTS: u32 = 3;
    for attempt in 0..MAX_ATTEMPTS {
        tracing::debug!(request_id = %request.data.id, "Persisting request state");
        // Erase the compile-time state so we can dispatch per-variant.
        let any_request = AnyRequest::from(request.clone());
        let result: Result<Option<RequestId>> = async {
            match any_request {
                AnyRequest::Pending(req) => {
                    // Back to pending: clears any daemon ownership fields.
                    let rows_affected = sqlx::query!(
                        r#"
                        UPDATE requests SET
                            state = 'pending',
                            retry_attempt = $2,
                            not_before = $3,
                            daemon_id = NULL,
                            claimed_at = NULL,
                            started_at = NULL
                        WHERE id = $1
                        "#,
                        *req.data.id as Uuid,
                        req.state.retry_attempt as i32,
                        req.state.not_before,
                    )
                    .execute(self.pools.write())
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to update request: {}", e))
                    })?
                    .rows_affected();
                    if rows_affected == 0 {
                        return Err(FusilladeError::RequestNotFound(req.data.id));
                    }
                }
                AnyRequest::Claimed(req) => {
                    let rows_affected = sqlx::query!(
                        r#"
                        UPDATE requests SET
                            state = 'claimed',
                            retry_attempt = $2,
                            daemon_id = $3,
                            claimed_at = $4,
                            started_at = NULL,
                            not_before = NULL
                        WHERE id = $1
                        "#,
                        *req.data.id as Uuid,
                        req.state.retry_attempt as i32,
                        *req.state.daemon_id as Uuid,
                        req.state.claimed_at,
                    )
                    .execute(self.pools.write())
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to update request: {}", e))
                    })?
                    .rows_affected();
                    if rows_affected == 0 {
                        return Err(FusilladeError::RequestNotFound(req.data.id));
                    }
                }
                AnyRequest::Processing(req) => {
                    let rows_affected = sqlx::query!(
                        r#"
                        UPDATE requests SET
                            state = 'processing',
                            retry_attempt = $2,
                            daemon_id = $3,
                            claimed_at = $4,
                            started_at = $5
                        WHERE id = $1
                        "#,
                        *req.data.id as Uuid,
                        req.state.retry_attempt as i32,
                        *req.state.daemon_id as Uuid,
                        req.state.claimed_at,
                        req.state.started_at,
                    )
                    .execute(self.pools.write())
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to update request: {}", e))
                    })?
                    .rows_affected();
                    if rows_affected == 0 {
                        return Err(FusilladeError::RequestNotFound(req.data.id));
                    }
                }
                AnyRequest::Completed(req) => {
                    // Oversized responses are rejected before touching the DB.
                    let response_size = calculate_response_body_size(&req.state.response_body)
                        .ok_or_else(|| {
                            FusilladeError::Other(anyhow!("Response body too large"))
                        })?;
                    let rows_affected = sqlx::query!(
                        r#"
                        UPDATE requests SET
                            state = 'completed',
                            response_status = $2,
                            response_body = $3,
                            claimed_at = $4,
                            started_at = $5,
                            completed_at = $6,
                            response_size = $7,
                            routed_model = $8
                        WHERE id = $1
                        "#,
                        *req.data.id as Uuid,
                        req.state.response_status as i16,
                        req.state.response_body,
                        req.state.claimed_at,
                        req.state.started_at,
                        req.state.completed_at,
                        response_size,
                        req.state.routed_model,
                    )
                    .execute(self.pools.write())
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to update request: {}", e))
                    })?
                    .rows_affected();
                    if rows_affected == 0 {
                        return Err(FusilladeError::RequestNotFound(req.data.id));
                    }
                }
                AnyRequest::Failed(req) => {
                    // Failure reason is serialized as JSON into the error column.
                    let error_json = serde_json::to_string(&req.state.reason).map_err(|e| {
                        FusilladeError::Other(anyhow!(
                            "Failed to serialize failure reason: {}",
                            e
                        ))
                    })?;
                    let response_size =
                        calculate_error_message_size(&error_json).ok_or_else(|| {
                            FusilladeError::Other(anyhow!("Error message too large"))
                        })?;
                    let rows_affected = sqlx::query!(
                        r#"
                        UPDATE requests SET
                            state = 'failed',
                            retry_attempt = $2,
                            error = $3,
                            failed_at = $4,
                            response_size = $5,
                            routed_model = $6
                        WHERE id = $1
                        "#,
                        *req.data.id as Uuid,
                        req.state.retry_attempt as i32,
                        error_json,
                        req.state.failed_at,
                        response_size,
                        req.state.routed_model,
                    )
                    .execute(self.pools.write())
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to update request: {}", e))
                    })?
                    .rows_affected();
                    if rows_affected == 0 {
                        return Err(FusilladeError::RequestNotFound(req.data.id));
                    }
                }
                AnyRequest::Canceled(req) => {
                    let rows_affected = sqlx::query!(
                        r#"
                        UPDATE requests SET
                            state = 'canceled',
                            canceled_at = $2
                        WHERE id = $1
                        "#,
                        *req.data.id as Uuid,
                        req.state.canceled_at,
                    )
                    .execute(self.pools.write())
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to update request: {}", e))
                    })?
                    .rows_affected();
                    if rows_affected == 0 {
                        return Err(FusilladeError::RequestNotFound(req.data.id));
                    }
                }
            }
            Ok(None)
        }
        .await;
        match result {
            Ok(val) => return Ok(val),
            // Not-found is definitive: the row is gone, retrying won't help.
            Err(FusilladeError::RequestNotFound(id)) => {
                return Err(FusilladeError::RequestNotFound(id));
            }
            Err(e) if attempt < MAX_ATTEMPTS - 1 => {
                tracing::warn!(
                    request_id = %request.data.id,
                    persist_attempt = attempt + 1,
                    error = %e,
                    "Failed to persist request state, retrying"
                );
                // Exponential backoff: 100ms * 2^attempt.
                tokio::time::sleep(std::time::Duration::from_millis(100 * 2u64.pow(attempt)))
                    .await;
            }
            Err(e) => return Err(e),
        }
    }
    // Unreachable in practice: the final attempt either returns Ok or Err
    // above, but the loop needs a fallthrough value.
    Err(FusilladeError::Other(anyhow!(
        "Failed to persist request state after {} attempts",
        MAX_ATTEMPTS
    )))
}
#[tracing::instrument(skip(self, ids), fields(count = ids.len()))]
/// Fetches requests by id and rehydrates each into its typed state.
///
/// The result is positionally aligned with `ids`: each element is either
/// the rehydrated request, `RequestNotFound` for missing rows, or an error
/// when the row's template has been deleted or a state-specific column is
/// unexpectedly NULL.
///
/// Caveats of this read path (grounded in the code below):
/// - `created_by` and `batch_metadata` are left empty — they are not
///   selected here.
/// - A `processing` request is given a fresh dummy result channel and an
///   abort handle from an already-finished task; it cannot be used to
///   observe or cancel the actual in-flight work.
async fn get_requests(&self, ids: Vec<RequestId>) -> Result<Vec<Result<AnyRequest>>> {
    let uuid_ids: Vec<Uuid> = ids.iter().map(|id| **id).collect();
    // LEFT JOIN on templates: template columns come back NULL when the
    // template has been deleted, which is handled explicitly below.
    let rows = sqlx::query!(
        r#"
        SELECT
            r.id, r.batch_id as "batch_id!", r.template_id as "template_id?", r.state,
            t.custom_id as "custom_id?", t.endpoint as "endpoint?", t.method as "method?",
            t.path as "path?", t.body as "body?", t.model as "model?", t.api_key as "api_key?",
            r.retry_attempt, r.not_before, r.daemon_id, r.claimed_at, r.started_at,
            r.response_status, r.response_body, r.completed_at, r.error, r.failed_at, r.canceled_at,
            b.expires_at as batch_expires_at, r.routed_model
        FROM requests r
        LEFT JOIN active_request_templates t ON r.template_id = t.id
        JOIN batches b ON r.batch_id = b.id
        WHERE r.id = ANY($1)
        "#,
        &uuid_ids,
    )
    .fetch_all(self.pools.read())
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch requests: {}", e)))?;
    let mut request_map: std::collections::HashMap<RequestId, Result<AnyRequest>> =
        std::collections::HashMap::new();
    for row in rows {
        let request_id = RequestId(row.id);
        // All template columns must be present to build RequestData.
        let data = match (
            row.template_id,
            row.endpoint,
            row.method,
            row.path,
            row.body,
            row.model,
            row.api_key,
        ) {
            (
                Some(template_id),
                Some(endpoint),
                Some(method),
                Some(path),
                Some(body),
                Some(model),
                Some(api_key),
            ) => RequestData {
                id: request_id,
                batch_id: BatchId(row.batch_id),
                template_id: TemplateId(template_id),
                custom_id: row.custom_id,
                endpoint,
                method,
                path,
                body,
                model,
                api_key,
                // Not selected by this query; see doc comment above.
                created_by: String::new(),
                batch_metadata: std::collections::HashMap::new(),
            },
            _ => {
                request_map.insert(
                    request_id,
                    Err(FusilladeError::Other(anyhow!(
                        "Request template has been deleted"
                    ))),
                );
                continue;
            }
        };
        let state = &row.state;
        // Map the stored state string onto the typed state machine; each
        // arm validates the columns that state requires.
        let any_request = match state.as_str() {
            "pending" => Ok(AnyRequest::Pending(Request {
                state: Pending {
                    retry_attempt: row.retry_attempt as u32,
                    not_before: row.not_before,
                    batch_expires_at: row.batch_expires_at,
                },
                data,
            })),
            "claimed" => Ok(AnyRequest::Claimed(Request {
                state: Claimed {
                    daemon_id: DaemonId(row.daemon_id.ok_or_else(|| {
                        FusilladeError::Other(anyhow!("Missing daemon_id for claimed request"))
                    })?),
                    claimed_at: row.claimed_at.ok_or_else(|| {
                        FusilladeError::Other(anyhow!("Missing claimed_at for claimed request"))
                    })?,
                    retry_attempt: row.retry_attempt as u32,
                    batch_expires_at: row.batch_expires_at,
                },
                data,
            })),
            "processing" => {
                // Placeholder plumbing: a fresh channel and a handle to an
                // already-completed task satisfy the Processing shape only.
                let (_tx, rx) = tokio::sync::mpsc::channel(1);
                let abort_handle = tokio::spawn(async {}).abort_handle();
                Ok(AnyRequest::Processing(Request {
                    state: Processing {
                        daemon_id: DaemonId(row.daemon_id.ok_or_else(|| {
                            FusilladeError::Other(anyhow!(
                                "Missing daemon_id for processing request"
                            ))
                        })?),
                        claimed_at: row.claimed_at.ok_or_else(|| {
                            FusilladeError::Other(anyhow!(
                                "Missing claimed_at for processing request"
                            ))
                        })?,
                        started_at: row.started_at.ok_or_else(|| {
                            FusilladeError::Other(anyhow!(
                                "Missing started_at for processing request"
                            ))
                        })?,
                        retry_attempt: row.retry_attempt as u32,
                        batch_expires_at: row.batch_expires_at,
                        result_rx: Arc::new(Mutex::new(rx)),
                        abort_handle,
                    },
                    data,
                }))
            }
            "completed" => Ok(AnyRequest::Completed(Request {
                state: Completed {
                    response_status: row.response_status.ok_or_else(|| {
                        FusilladeError::Other(anyhow!(
                            "Missing response_status for completed request"
                        ))
                    })? as u16,
                    response_body: row.response_body.ok_or_else(|| {
                        FusilladeError::Other(anyhow!(
                            "Missing response_body for completed request"
                        ))
                    })?,
                    claimed_at: row.claimed_at.ok_or_else(|| {
                        FusilladeError::Other(anyhow!(
                            "Missing claimed_at for completed request"
                        ))
                    })?,
                    started_at: row.started_at.ok_or_else(|| {
                        FusilladeError::Other(anyhow!(
                            "Missing started_at for completed request"
                        ))
                    })?,
                    completed_at: row.completed_at.ok_or_else(|| {
                        FusilladeError::Other(anyhow!(
                            "Missing completed_at for completed request"
                        ))
                    })?,
                    routed_model: row.routed_model.unwrap_or_else(|| data.model.clone()),
                },
                data,
            })),
            "failed" => {
                let error_json = row.error.ok_or_else(|| {
                    FusilladeError::Other(anyhow!("Missing error for failed request"))
                })?;
                // Legacy/non-JSON error payloads degrade to a NetworkError
                // carrying the raw string rather than failing the read.
                let reason: FailureReason =
                    serde_json::from_str(&error_json).unwrap_or_else(|_| {
                        FailureReason::NetworkError {
                            error: error_json.clone(),
                        }
                    });
                Ok(AnyRequest::Failed(Request {
                    state: Failed {
                        reason,
                        failed_at: row.failed_at.ok_or_else(|| {
                            FusilladeError::Other(anyhow!(
                                "Missing failed_at for failed request"
                            ))
                        })?,
                        retry_attempt: row.retry_attempt as u32,
                        batch_expires_at: row.batch_expires_at,
                        routed_model: row.routed_model.unwrap_or_else(|| data.model.clone()),
                    },
                    data,
                }))
            }
            "canceled" => Ok(AnyRequest::Canceled(Request {
                state: Canceled {
                    canceled_at: row.canceled_at.ok_or_else(|| {
                        FusilladeError::Other(anyhow!(
                            "Missing canceled_at for canceled request"
                        ))
                    })?,
                },
                data,
            })),
            _ => Err(FusilladeError::Other(anyhow!("Unknown state: {}", state))),
        };
        request_map.insert(request_id, any_request);
    }
    // Preserve input order; absent ids become RequestNotFound.
    Ok(ids
        .into_iter()
        .map(|id| {
            request_map
                .remove(&id)
                .unwrap_or_else(|| Err(FusilladeError::RequestNotFound(id)))
        })
        .collect())
}
/// Creates a file with the given name/description and request templates.
///
/// Convenience wrapper over [`Self::create_file_stream`]: the inputs are turned
/// into an in-memory item stream (metadata first, then every template in order).
/// An `Aborted` result from the stream path is impossible here, so it is
/// surfaced as an internal error.
#[tracing::instrument(skip(self, templates), fields(name = %name, template_count = templates.len()))]
async fn create_file(
    &self,
    name: String,
    description: Option<String>,
    templates: Vec<RequestTemplateInput>,
) -> Result<FileId> {
    use futures::stream;
    // Metadata item first, followed by one Template item per input.
    let mut items = Vec::with_capacity(templates.len() + 1);
    items.push(FileStreamItem::Metadata(FileMetadata {
        filename: Some(name),
        description,
        ..Default::default()
    }));
    items.extend(templates.into_iter().map(FileStreamItem::Template));
    match self.create_file_stream(stream::iter(items)).await? {
        FileStreamResult::Success(file_id) => Ok(file_id),
        FileStreamResult::Aborted => Err(FusilladeError::Other(anyhow!(
            "create_file produced an aborted stream result for an internally constructed stream"
        ))),
    }
}
/// Streams a file into the database: metadata items plus request templates.
///
/// All work runs inside a single transaction on the write pool:
/// - `Metadata` items are merged field-by-field; later non-`None` fields win.
/// - `Template` items are buffered and bulk-inserted in chunks sized by
///   `self.batch_insert_strategy`. The first template lazily inserts a stub
///   `files` row so the template rows have a parent id.
/// - `Abort` rolls the transaction back and returns [`FileStreamResult::Aborted`].
/// - `Error` (deprecated) returns a validation error; dropping the transaction
///   rolls it back implicitly.
///
/// After the stream ends, a `files` row is created if none exists yet (a stream
/// with no templates), then updated with the merged metadata, final status,
/// expiry, and `size_finalized = TRUE`.
#[tracing::instrument(skip(self, stream))]
async fn create_file_stream<S: Stream<Item = FileStreamItem> + Send + Unpin>(
    &self,
    mut stream: S,
) -> Result<FileStreamResult> {
    use futures::StreamExt;
    let mut tx =
        self.pools.write().begin().await.map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to begin transaction: {}", e))
        })?;
    let mut metadata = FileMetadata::default();
    // Set once the stub `files` row has been inserted (lazily, on first template).
    let mut file_id: Option<Uuid> = None;
    let mut template_count = 0;
    let batch_size = match self.batch_insert_strategy {
        BatchInsertStrategy::Batched { batch_size } => batch_size,
    };
    let mut template_buffer = Vec::with_capacity(batch_size);
    while let Some(item) = stream.next().await {
        match item {
            FileStreamItem::Metadata(meta) => {
                // Field-wise merge: keep previously seen values unless the new
                // metadata item explicitly provides one.
                if meta.filename.is_some() {
                    metadata.filename = meta.filename;
                }
                if meta.description.is_some() {
                    metadata.description = meta.description;
                }
                if meta.purpose.is_some() {
                    metadata.purpose = meta.purpose;
                }
                if meta.expires_after_anchor.is_some() {
                    metadata.expires_after_anchor = meta.expires_after_anchor;
                }
                if meta.expires_after_seconds.is_some() {
                    metadata.expires_after_seconds = meta.expires_after_seconds;
                }
                if meta.size_bytes.is_some() {
                    metadata.size_bytes = meta.size_bytes;
                }
                if meta.uploaded_by.is_some() {
                    metadata.uploaded_by = meta.uploaded_by;
                }
                if meta.api_key_id.is_some() {
                    metadata.api_key_id = meta.api_key_id;
                }
            }
            FileStreamItem::Template(template) => {
                if file_id.is_none() {
                    // First template: insert a stub row now so the bulk template
                    // inserts below have a valid file id to reference.
                    let new_id = Uuid::new_v4();
                    let stub_name = metadata
                        .filename
                        .clone()
                        .unwrap_or_else(|| format!("upload-{}", new_id));
                    let status = crate::batch::FileStatus::Processed.to_string();
                    sqlx::query!(
                        r#"
                        INSERT INTO files (id, name, status, created_at, updated_at)
                        VALUES ($1, $2, $3, NOW(), NOW())
                        "#,
                        new_id,
                        stub_name,
                        status,
                    )
                    .execute(&mut *tx)
                    .await
                    .map_err(|e| {
                        FusilladeError::Other(anyhow!("Failed to create file stub: {}", e))
                    })?;
                    file_id = Some(new_id);
                }
                // Buffer (template, ordinal) pairs; the ordinal preserves stream order.
                template_buffer.push((template, template_count));
                template_count += 1;
                if template_buffer.len() >= batch_size {
                    Self::insert_template_batch(&mut tx, file_id.unwrap(), &template_buffer)
                        .await?;
                    template_buffer.clear();
                }
            }
            FileStreamItem::Abort => {
                tx.rollback().await.map_err(|e| {
                    FusilladeError::Other(anyhow!(
                        "Failed to roll back aborted file stream transaction: {}",
                        e
                    ))
                })?;
                return Ok(FileStreamResult::Aborted);
            }
            #[allow(deprecated)]
            FileStreamItem::Error(err) => {
                // Early return drops `tx`, which rolls the transaction back.
                tracing::warn!("FileStreamItem::Error is deprecated; use Abort instead");
                return Err(FusilladeError::ValidationError(err));
            }
        }
    }
    // Flush the final partial chunk. `file_id` is necessarily Some here: the
    // buffer only gains entries after the stub row was inserted above.
    if !template_buffer.is_empty() {
        Self::insert_template_batch(&mut tx, file_id.unwrap(), &template_buffer).await?;
    }
    // A stream with zero templates never created the stub row; do it now.
    let fid = if let Some(id) = file_id {
        id
    } else {
        let new_id = Uuid::new_v4();
        let stub_name = metadata
            .filename
            .clone()
            .unwrap_or_else(|| format!("upload-{}", new_id));
        let status = crate::batch::FileStatus::Processed.to_string();
        sqlx::query!(
            r#"
            INSERT INTO files (id, name, status, created_at, updated_at)
            VALUES ($1, $2, $3, NOW(), NOW())
            "#,
            new_id,
            stub_name,
            status,
        )
        .execute(&mut *tx)
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to create file: {}", e)))?;
        new_id
    };
    let size_bytes = metadata.size_bytes.unwrap_or(0);
    let status = crate::batch::FileStatus::Processed.to_string();
    let purpose = metadata.purpose.clone();
    // Expiry: when `expires_after_seconds` is given it is applied from the anchor
    // ("created_at" is the only recognized anchor; unknown values warn and fall
    // back to now). Without an explicit expiry, files default to 30 days from now.
    let expires_at = if let Some(seconds) = metadata.expires_after_seconds {
        let anchor = if let Some(anchor_str) = metadata.expires_after_anchor.as_ref() {
            match anchor_str.as_str() {
                "created_at" => Utc::now(),
                _ => {
                    tracing::warn!(
                        anchor = anchor_str,
                        "Unknown expires_after_anchor value, defaulting to 'created_at'"
                    );
                    Utc::now()
                }
            }
        } else {
            Utc::now()
        };
        anchor.checked_add_signed(chrono::Duration::seconds(seconds))
    } else {
        Utc::now().checked_add_signed(chrono::Duration::days(30))
    };
    let description = metadata.description.clone();
    let uploaded_by = metadata.uploaded_by.clone();
    let name = metadata.filename.clone();
    // Finalize the files row with the merged metadata. COALESCE keeps the stub
    // name when no filename was ever supplied in the stream.
    sqlx::query!(
        r#"
        UPDATE files
        SET name = COALESCE($2, name),
            description = $3,
            size_bytes = $4,
            status = $5,
            purpose = $6,
            expires_at = $7,
            uploaded_by = $8,
            api_key_id = $9,
            size_finalized = TRUE,
            updated_at = NOW()
        WHERE id = $1
        "#,
        fid,
        name,
        description,
        size_bytes,
        status,
        purpose,
        expires_at,
        uploaded_by,
        metadata.api_key_id,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to update file metadata: {}", e)))?;
    tx.commit()
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to commit transaction: {}", e)))?;
    tracing::debug!(
        file_id = %fid,
        template_count = template_count,
        strategy = ?self.batch_insert_strategy,
        "Successfully created file with templates"
    );
    Ok(FileStreamResult::Success(FileId(fid)))
}
/// Fetches a file by id from the read pool, then applies the expiry check and
/// lazy size finalization passes before returning it.
#[tracing::instrument(skip(self), fields(file_id = %file_id))]
async fn get_file(&self, file_id: FileId) -> Result<File> {
    let read_pool = self.pools.read();
    let mut record = self.get_file_from_pool(file_id, read_pool).await?;
    self.check_and_mark_expired(&mut record).await?;
    self.maybe_finalize_file_size(&mut record).await?;
    Ok(record)
}
/// Same as `get_file`, but reads from the primary (write) pool — used when
/// replica lag must be avoided. Applies the same expiry and size-finalization
/// passes before returning.
#[tracing::instrument(skip(self), fields(file_id = %file_id))]
async fn get_file_from_primary_pool(&self, file_id: FileId) -> Result<File> {
    let primary = self.pools.write();
    let mut record = self.get_file_from_pool(file_id, primary).await?;
    self.check_and_mark_expired(&mut record).await?;
    self.maybe_finalize_file_size(&mut record).await?;
    Ok(record)
}
/// Collects the full content of a file into memory by draining the content
/// stream from offset 0 with no search filter. Returns the first stream error
/// encountered, if any.
async fn get_file_content(&self, file_id: FileId) -> Result<Vec<FileContentItem>> {
    let mut content_stream = self.get_file_content_stream(file_id, 0, None);
    let mut collected = Vec::new();
    while let Some(next_item) = content_stream.next().await {
        collected.push(next_item?);
    }
    Ok(collected)
}
/// Returns per-model aggregates over a file's request templates: the number of
/// templates and the summed request-body byte size, grouped and ordered by model.
#[tracing::instrument(skip(self), fields(file_id = %file_id))]
async fn get_file_template_stats(
    &self,
    file_id: FileId,
) -> Result<Vec<crate::batch::ModelTemplateStats>> {
    let rows = sqlx::query!(
        r#"
        SELECT
            model,
            COUNT(*)::BIGINT as "request_count!",
            SUM(body_byte_size)::BIGINT as "total_body_bytes!"
        FROM request_templates
        WHERE file_id = $1
        GROUP BY model
        ORDER BY model
        "#,
        *file_id as Uuid,
    )
    .fetch_all(self.pools.read())
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch template stats: {}", e)))?;
    // Translate the raw rows into the public stats type.
    let mut stats = Vec::with_capacity(rows.len());
    for row in rows {
        stats.push(crate::batch::ModelTemplateStats {
            model: row.model,
            request_count: row.request_count,
            total_body_bytes: row.total_body_bytes,
        });
    }
    Ok(stats)
}
/// Streams a file's content as [`FileContentItem`]s through a bounded channel.
///
/// A background task first looks up the file's `purpose` and then dispatches to
/// the matching streamer: `batch_output`, `batch_error`, or (anything else)
/// request templates. Lookup failures are delivered through the stream itself
/// as a single `Err` item.
#[tracing::instrument(skip(self), fields(file_id = %file_id, search = ?search))]
fn get_file_content_stream(
    &self,
    file_id: FileId,
    offset: usize,
    search: Option<String>,
) -> Pin<Box<dyn Stream<Item = Result<FileContentItem>> + Send>> {
    let pool = self.pools.read().clone();
    let (sender, receiver) = mpsc::channel(self.download_buffer_size);
    let offset = offset as i64;
    tokio::spawn(async move {
        // Determine which streamer to run by reading the file's purpose.
        let lookup = sqlx::query!(
            r#"
            SELECT purpose
            FROM files
            WHERE id = $1 AND deleted_at IS NULL
            "#,
            *file_id as Uuid,
        )
        .fetch_one(&pool)
        .await;
        let purpose = match lookup {
            Ok(record) => record.purpose,
            Err(e) => {
                // Surface the failure as a stream item; ignore send errors
                // (the consumer may already have dropped the receiver).
                let _ = sender
                    .send(Err(FusilladeError::Other(anyhow!(
                        "Failed to fetch file: {}",
                        e
                    ))))
                    .await;
                return;
            }
        };
        match purpose.as_deref() {
            Some("batch_output") => {
                Self::stream_batch_output(pool, file_id, offset, search, sender).await
            }
            Some("batch_error") => {
                Self::stream_batch_error(pool, file_id, offset, search, sender).await
            }
            _ => Self::stream_request_templates(pool, file_id, offset, search, sender).await,
        }
    });
    Box::pin(ReceiverStream::new(receiver))
}
/// Lists non-deleted files matching `filter`, with keyset pagination on
/// `(created_at, id)` and on-the-fly size estimation for virtual batch
/// output/error files whose size has not been finalized yet.
#[tracing::instrument(skip(self, filter), fields(uploaded_by = ?filter.uploaded_by, status = ?filter.status, purpose = ?filter.purpose, after = ?filter.after, limit = ?filter.limit))]
async fn list_files(&self, filter: crate::batch::FileFilter) -> Result<Vec<File>> {
    use sqlx::QueryBuilder;
    // Resolve the pagination cursor: the created_at of the `after` file, if any.
    let after_created_at = if let Some(after_id) = &filter.after {
        sqlx::query!(
            r#"SELECT created_at FROM files WHERE id = $1"#,
            **after_id as Uuid
        )
        .fetch_optional(self.pools.read())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch after cursor: {}", e)))?
        .map(|row| row.created_at)
    } else {
        None
    };
    // Base query. The two LATERAL joins only apply to batch_output/batch_error
    // files: `counts` aggregates request states for the owning batch, and
    // `size_calc` sums response sizes for files whose size is not yet finalized.
    let mut query_builder = QueryBuilder::new(
        r#"
        SELECT
            f.id, f.name, f.description, f.size_bytes, f.size_finalized,
            f.status, f.error_message, f.purpose, f.expires_at, f.deleted_at,
            f.uploaded_by, f.created_at, f.updated_at, f.api_key_id,
            b.id as batch_id,
            b.total_requests,
            COALESCE(counts.completed, 0)::BIGINT as completed_requests,
            COALESCE(counts.failed, 0)::BIGINT as failed_requests,
            COALESCE(counts.canceled, 0)::BIGINT as canceled_requests,
            COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests,
            size_calc.calculated_size
        FROM files f
        LEFT JOIN batches b ON (
            (f.purpose = 'batch_output' AND b.output_file_id = f.id) OR
            (f.purpose = 'batch_error' AND b.error_file_id = f.id)
        )
        LEFT JOIN LATERAL (
            SELECT
                COUNT(*) FILTER (WHERE r.state = 'completed') as completed,
                COUNT(*) FILTER (WHERE r.state = 'failed') as failed,
                COUNT(*) FILTER (WHERE r.state = 'canceled' OR (r.state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled,
                COUNT(*) FILTER (WHERE r.state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress
            FROM requests r
            WHERE r.batch_id = b.id
        ) counts ON (f.purpose IN ('batch_output', 'batch_error'))
        LEFT JOIN LATERAL (
            SELECT SUM(r.response_size)::BIGINT as calculated_size
            FROM requests r
            WHERE r.batch_id = b.id
              AND ((f.purpose = 'batch_output' AND r.state = 'completed') OR
                   (f.purpose = 'batch_error' AND r.state = 'failed'))
        ) size_calc ON (f.purpose IN ('batch_output', 'batch_error') AND f.size_finalized = FALSE AND b.id IS NOT NULL)
        "#,
    );
    query_builder.push(" WHERE f.deleted_at IS NULL");
    // Optional equality / LIKE filters, each appended only when present.
    if let Some(uploaded_by) = &filter.uploaded_by {
        query_builder.push(" AND f.uploaded_by = ");
        query_builder.push_bind(uploaded_by);
    }
    if let Some(status) = &filter.status {
        query_builder.push(" AND f.status = ");
        query_builder.push_bind(status);
    }
    if let Some(purpose) = &filter.purpose {
        query_builder.push(" AND f.purpose = ");
        query_builder.push_bind(purpose);
    }
    if let Some(search) = &filter.search {
        // Case-insensitive substring match on the file name.
        let search_pattern = format!("%{}%", search.to_lowercase());
        query_builder.push(" AND LOWER(f.name) LIKE ");
        query_builder.push_bind(search_pattern);
    }
    if let Some(api_key_ids) = &filter.api_key_ids {
        query_builder.push(" AND f.api_key_id = ANY(");
        query_builder.push_bind(api_key_ids.as_slice());
        query_builder.push(")");
    }
    // Keyset pagination: rows strictly past the (created_at, id) cursor, in the
    // requested direction. Only applied when the cursor row was found.
    if let (Some(after_id), Some(after_ts)) = (&filter.after, after_created_at) {
        let comparison = if filter.ascending { ">" } else { "<" };
        query_builder.push(" AND (f.created_at ");
        query_builder.push(comparison);
        query_builder.push(" ");
        query_builder.push_bind(after_ts);
        query_builder.push(" OR (f.created_at = ");
        query_builder.push_bind(after_ts);
        query_builder.push(" AND f.id ");
        query_builder.push(comparison);
        query_builder.push(" ");
        query_builder.push_bind(**after_id as Uuid);
        query_builder.push("))");
    }
    // Order must match the pagination tuple: (created_at, id), same direction.
    let order_direction = if filter.ascending { "ASC" } else { "DESC" };
    query_builder.push(" ORDER BY f.created_at ");
    query_builder.push(order_direction);
    query_builder.push(", f.id ");
    query_builder.push(order_direction);
    if let Some(limit) = filter.limit {
        query_builder.push(" LIMIT ");
        query_builder.push_bind(limit as i64);
    }
    let rows = query_builder
        .build()
        .fetch_all(self.pools.read())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to list files: {}", e)))?;
    let mut files = Vec::new();
    for row in rows {
        // Dynamic (non-macro) query: decode each column explicitly with try_get.
        let id: Uuid = row
            .try_get("id")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read id: {}", e)))?;
        let name: String = row
            .try_get("name")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read name: {}", e)))?;
        let description: Option<String> = row
            .try_get("description")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read description: {}", e)))?;
        let mut size_bytes: i64 = row
            .try_get("size_bytes")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read size_bytes: {}", e)))?;
        let size_finalized: bool = row.try_get("size_finalized").map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to read size_finalized: {}", e))
        })?;
        let status_str: String = row
            .try_get("status")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read status: {}", e)))?;
        let status = status_str
            .parse::<crate::batch::FileStatus>()
            .map_err(|e| {
                FusilladeError::Other(anyhow!("Invalid file status '{}': {}", status_str, e))
            })?;
        let error_message: Option<String> = row.try_get("error_message").map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to read error_message: {}", e))
        })?;
        let purpose_str: Option<String> = row
            .try_get("purpose")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read purpose: {}", e)))?;
        let purpose = purpose_str
            .map(|s| s.parse::<crate::batch::Purpose>())
            .transpose()
            .map_err(|e| FusilladeError::Other(anyhow!("Invalid purpose: {}", e)))?;
        let expires_at: Option<chrono::DateTime<Utc>> = row
            .try_get("expires_at")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read expires_at: {}", e)))?;
        let deleted_at: Option<chrono::DateTime<Utc>> = row
            .try_get("deleted_at")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read deleted_at: {}", e)))?;
        let uploaded_by: Option<String> = row
            .try_get("uploaded_by")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read uploaded_by: {}", e)))?;
        let created_at: chrono::DateTime<Utc> = row
            .try_get("created_at")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read created_at: {}", e)))?;
        let updated_at: chrono::DateTime<Utc> = row
            .try_get("updated_at")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read updated_at: {}", e)))?;
        let api_key_id: Option<Uuid> = row
            .try_get("api_key_id")
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to read api_key_id: {}", e)))?;
        // Virtual output/error files without a finalized size: report the
        // estimated size, and kick off finalization if the batch looks done.
        if let Some(estimated_size) =
            self.calculate_virtual_file_size_from_row(&row, &purpose, size_finalized)?
        {
            size_bytes = estimated_size;
            let file_id = FileId(id);
            self.spawn_finalize_if_complete(&row, file_id, estimated_size);
        }
        let mut file = File {
            id: FileId(id),
            name,
            description,
            size_bytes,
            size_finalized,
            status,
            error_message,
            purpose,
            expires_at,
            deleted_at,
            uploaded_by,
            created_at,
            updated_at,
            api_key_id,
        };
        self.check_and_mark_expired(&mut file).await?;
        files.push(file);
    }
    Ok(files)
}
/// Soft-deletes a file and detaches every batch reference to it, all in one
/// transaction. Batches that used the file as input and are not yet terminal
/// are marked cancelling/cancelled. Returns an error (and rolls back) when the
/// file does not exist or was already deleted.
#[tracing::instrument(skip(self), fields(file_id = %file_id))]
async fn delete_file(&self, file_id: FileId) -> Result<()> {
    let mut tx =
        self.pools.write().begin().await.map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to begin transaction: {}", e))
        })?;
    // Batches using this file as their input: cancel them unless they already
    // reached a terminal state (completed/failed/cancelled), then detach the
    // file reference. COALESCE keeps any pre-existing cancel timestamps.
    sqlx::query!(
        r#"
        UPDATE batches
        SET cancelling_at = CASE
                WHEN completed_at IS NULL AND failed_at IS NULL AND cancelled_at IS NULL
                THEN COALESCE(cancelling_at, NOW())
                ELSE cancelling_at
            END,
            cancelled_at = CASE
                WHEN completed_at IS NULL AND failed_at IS NULL AND cancelled_at IS NULL
                THEN COALESCE(cancelled_at, NOW())
                ELSE cancelled_at
            END,
            file_id = NULL
        WHERE file_id = $1
        "#,
        *file_id as Uuid,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to update batches: {}", e)))?;
    // Clear any batch that references this file as its virtual output file.
    sqlx::query!(
        r#"
        UPDATE batches
        SET output_file_id = NULL
        WHERE output_file_id = $1
        "#,
        *file_id as Uuid,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!("Failed to clear output_file_id reference: {}", e))
    })?;
    // Clear any batch that references this file as its virtual error file.
    sqlx::query!(
        r#"
        UPDATE batches
        SET error_file_id = NULL
        WHERE error_file_id = $1
        "#,
        *file_id as Uuid,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!("Failed to clear error_file_id reference: {}", e))
    })?;
    // Soft-delete the file itself; the deleted_at guard makes this idempotent
    // in the sense that an already-deleted file reports "not found".
    let rows_affected = sqlx::query!(
        r#"
        UPDATE files
        SET deleted_at = NOW(),
            status = 'deleted'
        WHERE id = $1
          AND deleted_at IS NULL
        "#,
        *file_id as Uuid,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to soft-delete file: {}", e)))?
    .rows_affected();
    if rows_affected == 0 {
        // Nothing was deleted: undo the batch updates above and report the miss.
        tx.rollback()
            .await
            .map_err(|e| FusilladeError::Other(anyhow!("Failed to rollback: {}", e)))?;
        return Err(FusilladeError::Other(anyhow!("File not found")));
    }
    tx.commit()
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to commit transaction: {}", e)))?;
    Ok(())
}
/// Creates a batch: first the batch record, then its requests (one per template
/// in the input file). If populating requests fails, the batch is marked failed
/// (best effort) and the error is propagated; otherwise the fully populated
/// batch is re-read from the primary pool.
#[tracing::instrument(level = "debug", skip(self, input), fields(file_id = %input.file_id))]
async fn create_batch(&self, input: BatchInput) -> Result<Batch> {
    let source_file = input.file_id;
    let record = self.create_batch_record(input).await?;
    match self.populate_batch(record.id, source_file).await {
        Ok(()) => self.get_batch_from_pool(record.id, self.pools.write()).await,
        Err(e) => {
            // Best effort: record the failure on the batch, then surface it.
            let _ = self.mark_batch_failed(record.id, &e.to_string()).await;
            Err(e)
        }
    }
}
/// Inserts the batch row plus its two virtual (output and error) files, all in
/// one transaction, and returns an in-memory [`Batch`] representation.
///
/// The completion window (e.g. "24h", "7d") is parsed with humantime and added
/// to "now" to produce `expires_at`. The returned struct is built locally —
/// request counters start at zero and are NOT re-read from the database.
#[tracing::instrument(level = "debug", skip(self))]
async fn create_batch_record(&self, input: BatchInput) -> Result<Batch> {
    let now = Utc::now();
    // humantime -> std::time::Duration -> chrono::Duration, each step validated.
    let std_duration = humantime::parse_duration(&input.completion_window).map_err(|e| {
        FusilladeError::Other(anyhow!(
            "Invalid completion_window '{}': {}. Expected format like '24h', '7d', etc.",
            input.completion_window,
            e
        ))
    })?;
    let chrono_duration = chrono::Duration::from_std(std_duration).map_err(|e| {
        FusilladeError::Other(anyhow!(
            "Failed to convert completion_window duration: {}",
            e
        ))
    })?;
    // checked_add_signed guards against overflow of the timestamp arithmetic.
    let expires_at = now.checked_add_signed(chrono_duration).ok_or_else(|| {
        FusilladeError::Other(anyhow!(
            "Expiration time overflow when calculating expires_at from completion_window '{}'",
            input.completion_window
        ))
    })?;
    let total_requests = input.total_requests.unwrap_or(0);
    let mut tx =
        self.pools.write().begin().await.map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to begin transaction: {}", e))
        })?;
    // NULLIF(TRIM($8), '') stores NULL instead of an empty/whitespace api_key;
    // created_by falls back to '' when not provided.
    let row = sqlx::query!(
        r#"
        INSERT INTO batches (file_id, endpoint, completion_window, metadata, created_by, expires_at, api_key_id, api_key, total_requests)
        VALUES ($1, $2, $3, $4, COALESCE($5, ''), $6, $7, NULLIF(TRIM($8), ''), $9)
        RETURNING id, created_at
        "#,
        *input.file_id as Uuid,
        input.endpoint,
        input.completion_window,
        input.metadata,
        input.created_by,
        expires_at,
        input.api_key_id,
        input.api_key,
        total_requests,
    )
    .fetch_one(&mut *tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to create batch record: {}", e)))?;
    // Create the virtual output/error files and link them back onto the batch.
    let output_file_id = self
        .create_virtual_output_file(&mut tx, row.id, input.created_by.as_deref().unwrap_or(""))
        .await?;
    let error_file_id = self
        .create_virtual_error_file(&mut tx, row.id, input.created_by.as_deref().unwrap_or(""))
        .await?;
    sqlx::query!(
        r#"
        UPDATE batches SET output_file_id = $2, error_file_id = $3 WHERE id = $1
        "#,
        row.id,
        output_file_id,
        error_file_id,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!("Failed to update batch with file IDs: {}", e))
    })?;
    tx.commit()
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to commit transaction: {}", e)))?;
    // Mirror what was just inserted; counters/timestamps start in their
    // "freshly created" state rather than being read back from the database.
    Ok(Batch {
        id: BatchId(row.id),
        file_id: Some(input.file_id),
        created_at: row.created_at,
        metadata: input.metadata,
        completion_window: input.completion_window,
        endpoint: input.endpoint,
        output_file_id: Some(FileId(output_file_id)),
        error_file_id: Some(FileId(error_file_id)),
        created_by: input.created_by.unwrap_or_default(),
        expires_at,
        cancelling_at: None,
        errors: None,
        total_requests,
        pending_requests: 0,
        in_progress_requests: 0,
        completed_requests: 0,
        failed_requests: 0,
        canceled_requests: 0,
        requests_started_at: None,
        finalizing_at: None,
        completed_at: None,
        failed_at: None,
        cancelled_at: None,
        deleted_at: None,
        notification_sent_at: None,
        api_key_id: input.api_key_id,
    })
}
/// Materializes one `pending` request per template in the source file, then
/// records the resulting count on the batch. Both statements run in a single
/// transaction. Fails with a validation error when the file has no templates
/// (the transaction is dropped and thereby rolled back).
#[tracing::instrument(level = "debug", skip(self), fields(batch_id = %batch_id))]
async fn populate_batch(&self, batch_id: BatchId, file_id: FileId) -> Result<()> {
    let mut tx = self
        .pools
        .write()
        .begin()
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to begin transaction: {}", e)))?;
    // One INSERT..SELECT copies every template into a pending request row.
    let inserted = sqlx::query!(
        r#"
        INSERT INTO requests (batch_id, template_id, state, custom_id, retry_attempt, model)
        SELECT $1, id, 'pending', custom_id, 0, model
        FROM request_templates
        WHERE file_id = $2
        "#,
        *batch_id as Uuid,
        *file_id as Uuid,
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to create requests: {}", e)))?
    .rows_affected();
    if inserted == 0 {
        // Early return drops `tx`, rolling back the (empty) transaction.
        return Err(FusilladeError::ValidationError(
            "Cannot populate batch from file with no templates".to_string(),
        ));
    }
    // Record the request count and mark when request processing began.
    sqlx::query!(
        r#"
        UPDATE batches
        SET total_requests = $2,
            requests_started_at = NOW()
        WHERE id = $1
        "#,
        *batch_id as Uuid,
        inserted as i64
    )
    .execute(&mut *tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to update batch metadata: {}", e)))?;
    tx.commit()
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to commit transaction: {}", e)))?;
    Ok(())
}
/// Fetches a batch by id using the read pool.
#[tracing::instrument(level = "debug", skip(self), fields(batch_id = %batch_id))]
async fn get_batch(&self, batch_id: BatchId) -> Result<Batch> {
    let read_pool = self.pools.read();
    self.get_batch_from_pool(batch_id, read_pool).await
}
/// Returns a lightweight status summary for a batch: identifying fields plus
/// per-state request counts aggregated in a LATERAL subquery. Requests in
/// non-terminal states count as "canceled" once the batch has `cancelling_at`
/// set. Errors when the batch is missing or soft-deleted.
#[tracing::instrument(level = "debug", skip(self), fields(batch_id = %batch_id))]
async fn get_batch_status(&self, batch_id: BatchId) -> Result<BatchStatus> {
    let mut query_builder = QueryBuilder::new(
        r#"
        SELECT
            b.id as batch_id,
            b.file_id,
            f.name as file_name,
            b.total_requests,
            b.requests_started_at as started_at,
            b.failed_at,
            b.created_at,
            COALESCE(counts.pending, 0)::BIGINT as pending_requests,
            COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests,
            COALESCE(counts.completed, 0)::BIGINT as completed_requests,
            COALESCE(counts.failed, 0)::BIGINT as failed_requests,
            COALESCE(counts.canceled, 0)::BIGINT as canceled_requests
        FROM batches b
        LEFT JOIN files f ON f.id = b.file_id
        LEFT JOIN LATERAL (
            SELECT
                COUNT(*) FILTER (WHERE state = 'pending' AND b.cancelling_at IS NULL) as pending,
                COUNT(*) FILTER (WHERE state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress,
                COUNT(*) FILTER (WHERE state = 'completed') as completed,
                COUNT(*) FILTER (WHERE state = 'failed') as failed,
                COUNT(*) FILTER (WHERE state = 'canceled' OR (state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled
            FROM requests
            WHERE batch_id = b.id
        ) counts ON TRUE
        WHERE b.id = "#,
    );
    query_builder.push_bind(*batch_id as Uuid);
    query_builder.push(" AND b.deleted_at IS NULL");
    let row = query_builder
        .build()
        .fetch_optional(self.pools.read())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch batch status: {}", e)))?
        .ok_or_else(|| FusilladeError::Other(anyhow!("Batch not found")))?;
    // Decode the dynamic row via the shared macro.
    Ok(batch_status_from_dynamic_row!(row))
}
async fn get_batch_by_output_file_id(
&self,
file_id: FileId,
file_type: OutputFileType,
) -> Result<Option<Batch>> {
let mut query_builder = QueryBuilder::new(
r#"
SELECT
b.id, b.file_id, b.endpoint, b.completion_window, b.metadata,
b.output_file_id, b.error_file_id, b.created_by, b.created_at,
b.expires_at, b.cancelling_at, b.errors,
b.total_requests,
b.requests_started_at,
b.finalizing_at,
b.completed_at,
b.failed_at,
b.cancelled_at,
b.deleted_at,
b.notification_sent_at,
b.api_key_id,
COALESCE(counts.pending, 0)::BIGINT as pending_requests,
COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests,
COALESCE(counts.completed, 0)::BIGINT as completed_requests,
COALESCE(counts.failed, 0)::BIGINT as failed_requests,
COALESCE(counts.canceled, 0)::BIGINT as canceled_requests
FROM batches b
LEFT JOIN LATERAL (
SELECT
COUNT(*) FILTER (WHERE state = 'pending' AND b.cancelling_at IS NULL) as pending,
COUNT(*) FILTER (WHERE state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress,
COUNT(*) FILTER (WHERE state = 'completed') as completed,
COUNT(*) FILTER (WHERE state = 'failed') as failed,
COUNT(*) FILTER (WHERE state = 'canceled' OR (state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled
FROM requests
WHERE batch_id = b.id
) counts ON TRUE
WHERE "#,
);
match file_type {
OutputFileType::Output => {
query_builder.push("b.output_file_id = ");
query_builder.push_bind(*file_id as Uuid);
query_builder.push(" AND b.deleted_at IS NULL");
let row = query_builder
.build()
.fetch_optional(self.pools.read())
.await
.map_err(|e| {
FusilladeError::Other(anyhow!("Failed to get batch by output file: {}", e))
})?;
Ok(row.map(|row| batch_from_dynamic_row!(row)))
}
OutputFileType::Error => {
query_builder.push("b.error_file_id = ");
query_builder.push_bind(*file_id as Uuid);
query_builder.push(" AND b.deleted_at IS NULL");
let row = query_builder
.build()
.fetch_optional(self.pools.read())
.await
.map_err(|e| {
FusilladeError::Other(anyhow!("Failed to get batch by error file: {}", e))
})?;
Ok(row.map(|row| batch_from_dynamic_row!(row)))
}
}
}
/// Lists batches with optional filters, keyset pagination, and an optional
/// "active batches first" ordering tier.
///
/// Structure: a CTE (`filtered`) applies all filters, the cursor predicate, the
/// ordering, and the LIMIT first; the outer query then joins per-state request
/// counts only for that limited page. Optional filters use the NULL-bind trick
/// (`$n IS NULL OR <condition on $n>`) so every combination shares one query
/// shape. The cursor tuple is `(created_at, id)` descending, extended to
/// `(priority, created_at, id)` when `active_first` is set — priority 0 means
/// the batch has no terminal/cancelling timestamp yet.
#[tracing::instrument(skip(self), fields(created_by = ?filter.created_by, limit = filter.limit))]
async fn list_batches(&self, filter: ListBatchesFilter) -> Result<Vec<Batch>> {
    let ListBatchesFilter {
        created_by,
        search,
        after,
        limit,
        api_key_ids,
        status,
        created_after,
        created_before,
        active_first,
    } = filter;
    let limit = limit.unwrap_or(100);
    // SQL expression for the "still active" tier (0 = active, 1 = terminal).
    let priority_expr = "CASE WHEN b.completed_at IS NULL AND b.failed_at IS NULL \
        AND b.cancelled_at IS NULL AND b.cancelling_at IS NULL THEN 0 ELSE 1 END";
    // Resolve the `after` cursor: its created_at, and (when active_first) its
    // priority tier so the cursor comparison can span both tiers.
    let (after_created_at, after_id, after_priority) = if let Some(after_id) = after {
        if active_first {
            let row = sqlx::query!(
                r#"
                SELECT b.created_at,
                       CASE WHEN b.completed_at IS NULL AND b.failed_at IS NULL
                                 AND b.cancelled_at IS NULL AND b.cancelling_at IS NULL
                            THEN 0 ELSE 1 END as "priority!: i32"
                FROM batches b
                WHERE b.id = $1
                "#,
                *after_id as Uuid,
            )
            .fetch_optional(self.pools.read())
            .await
            .map_err(|e| {
                FusilladeError::Other(anyhow!("Failed to fetch after batch: {}", e))
            })?;
            match row {
                Some(r) => (
                    Some(r.created_at),
                    Some(*after_id as Uuid),
                    Some(r.priority),
                ),
                None => (None, Some(*after_id as Uuid), None),
            }
        } else {
            let row = sqlx::query!(
                r#"
                SELECT created_at
                FROM batches
                WHERE id = $1
                "#,
                *after_id as Uuid,
            )
            .fetch_optional(self.pools.read())
            .await
            .map_err(|e| {
                FusilladeError::Other(anyhow!("Failed to fetch after batch: {}", e))
            })?;
            (row.map(|r| r.created_at), Some(*after_id as Uuid), None)
        }
    } else {
        (None, None, None)
    };
    let search_pattern = search.as_ref().map(|s| format!("%{}%", s.to_lowercase()));
    // Phase 1 (CTE): filter + cursor + order + limit over raw batch rows.
    let mut query_builder = QueryBuilder::new(
        r#"
        WITH filtered AS (
            SELECT b.*, ("#,
    );
    query_builder.push(priority_expr);
    query_builder.push(
        r#") AS priority
            FROM batches b
            LEFT JOIN files f ON b.file_id = f.id
            WHERE b.deleted_at IS NULL
              AND ("#,
    );
    // Optional created_by filter via the NULL-bind trick.
    query_builder.push_bind(&created_by);
    query_builder.push("::TEXT IS NULL OR b.created_by = ");
    query_builder.push_bind(&created_by);
    query_builder.push(")");
    if active_first {
        // Cursor over (priority, created_at DESC, id DESC): rows in a later
        // tier, or same tier but strictly past the cursor position.
        query_builder.push(" AND (");
        query_builder.push_bind(after_priority);
        query_builder.push("::INT IS NULL OR (");
        query_builder.push(priority_expr);
        query_builder.push(") > ");
        query_builder.push_bind(after_priority);
        query_builder.push(" OR ((");
        query_builder.push(priority_expr);
        query_builder.push(") = ");
        query_builder.push_bind(after_priority);
        query_builder.push(" AND b.created_at < ");
        query_builder.push_bind(after_created_at);
        query_builder.push(") OR ((");
        query_builder.push(priority_expr);
        query_builder.push(") = ");
        query_builder.push_bind(after_priority);
        query_builder.push(" AND b.created_at = ");
        query_builder.push_bind(after_created_at);
        query_builder.push(" AND b.id < ");
        query_builder.push_bind(after_id);
        query_builder.push("))");
    } else {
        // Plain (created_at DESC, id DESC) keyset cursor.
        query_builder.push(" AND (");
        query_builder.push_bind(after_created_at);
        query_builder.push("::TIMESTAMPTZ IS NULL OR b.created_at < ");
        query_builder.push_bind(after_created_at);
        query_builder.push(" OR (b.created_at = ");
        query_builder.push_bind(after_created_at);
        query_builder.push(" AND b.id < ");
        query_builder.push_bind(after_id);
        query_builder.push("))");
    }
    // Optional search: matches metadata text, input file name, or batch id.
    query_builder.push(" AND (");
    query_builder.push_bind(&search_pattern);
    query_builder.push("::TEXT IS NULL OR LOWER(b.metadata::text) LIKE ");
    query_builder.push_bind(&search_pattern);
    query_builder.push(" OR LOWER(f.name) LIKE ");
    query_builder.push_bind(&search_pattern);
    query_builder.push(" OR b.id::text LIKE ");
    query_builder.push_bind(&search_pattern);
    query_builder.push(")");
    if let Some(api_key_ids) = &api_key_ids {
        query_builder.push(" AND b.api_key_id = ANY(");
        query_builder.push_bind(api_key_ids.as_slice());
        query_builder.push(")");
    }
    if let Some(created_after) = &created_after {
        query_builder.push(" AND b.created_at >= ");
        query_builder.push_bind(*created_after);
    }
    if let Some(created_before) = &created_before {
        query_builder.push(" AND b.created_at <= ");
        query_builder.push_bind(*created_before);
    }
    // Status filter: each value maps to a timestamp-based predicate; "expired"
    // also covers batches that only reached a terminal state after expiry.
    if let Some(ref status) = status {
        match status.as_str() {
            "in_progress" => {
                query_builder.push(" AND b.completed_at IS NULL AND b.failed_at IS NULL AND b.cancelled_at IS NULL AND b.cancelling_at IS NULL");
            }
            "completed" => {
                query_builder.push(" AND b.completed_at IS NOT NULL");
            }
            "failed" => {
                query_builder.push(" AND b.failed_at IS NOT NULL AND b.completed_at IS NULL");
            }
            "cancelled" => {
                query_builder
                    .push(" AND (b.cancelled_at IS NOT NULL OR b.cancelling_at IS NOT NULL)");
            }
            "expired" => {
                query_builder.push(
                    " AND b.expires_at IS NOT NULL AND (\
                        (b.expires_at < NOW() AND b.completed_at IS NULL AND b.failed_at IS NULL AND b.cancelled_at IS NULL AND b.cancelling_at IS NULL) \
                        OR (b.completed_at IS NOT NULL AND b.completed_at > b.expires_at) \
                        OR (b.failed_at IS NOT NULL AND b.failed_at > b.expires_at) \
                        OR (b.cancelled_at IS NOT NULL AND b.cancelled_at > b.expires_at)\
                    )",
                );
            }
            unknown => {
                return Err(FusilladeError::Other(anyhow!(
                    "Unknown batch status filter: '{}'. Valid values: in_progress, completed, failed, cancelled, expired",
                    unknown
                )));
            }
        }
    }
    if active_first {
        query_builder.push(" ORDER BY priority ASC, b.created_at DESC, b.id DESC LIMIT ");
    } else {
        query_builder.push(" ORDER BY b.created_at DESC, b.id DESC LIMIT ");
    }
    query_builder.push_bind(limit);
    // Phase 2 must re-apply the same ordering: the outer join does not preserve
    // the CTE's row order.
    let phase2_order = if active_first {
        "ORDER BY b.priority ASC, b.created_at DESC, b.id DESC"
    } else {
        "ORDER BY b.created_at DESC, b.id DESC"
    };
    // Phase 2: join per-state request counts onto the limited page only.
    query_builder.push(
        r#"
        )
        SELECT
            b.id, b.file_id, b.endpoint, b.completion_window, b.metadata,
            b.output_file_id, b.error_file_id, b.created_by, b.created_at,
            b.expires_at, b.cancelling_at, b.errors,
            b.total_requests,
            b.requests_started_at,
            b.finalizing_at,
            b.completed_at,
            b.failed_at,
            b.cancelled_at,
            b.deleted_at,
            b.notification_sent_at,
            b.api_key_id,
            COALESCE(counts.pending, 0)::BIGINT as pending_requests,
            COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests,
            COALESCE(counts.completed, 0)::BIGINT as completed_requests,
            COALESCE(counts.failed, 0)::BIGINT as failed_requests,
            COALESCE(counts.canceled, 0)::BIGINT as canceled_requests
        FROM filtered b
        LEFT JOIN LATERAL (
            SELECT
                COUNT(*) FILTER (WHERE state = 'pending' AND b.cancelling_at IS NULL) as pending,
                COUNT(*) FILTER (WHERE state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress,
                COUNT(*) FILTER (WHERE state = 'completed') as completed,
                COUNT(*) FILTER (WHERE state = 'failed') as failed,
                COUNT(*) FILTER (WHERE state = 'canceled' OR (state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled
            FROM requests
            WHERE batch_id = b.id
        ) counts ON TRUE
        "#,
    );
    query_builder.push(phase2_order);
    let rows = query_builder
        .build()
        .fetch_all(self.pools.read())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to list batches: {}", e)))?;
    Ok(rows
        .into_iter()
        .map(|row| batch_from_dynamic_row!(row))
        .collect())
}
/// Returns per-batch status summaries for every non-deleted batch created
/// from the given input file, newest first.
async fn list_file_batches(&self, file_id: FileId) -> Result<Vec<BatchStatus>> {
    // Static SELECT with a LATERAL per-batch request-state count; only the
    // file id is bound dynamically.
    let mut qb = QueryBuilder::new(
        r#"
SELECT
b.id as batch_id,
b.file_id,
f.name as file_name,
b.total_requests,
b.requests_started_at as started_at,
b.failed_at,
b.created_at,
COALESCE(counts.pending, 0)::BIGINT as pending_requests,
COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests,
COALESCE(counts.completed, 0)::BIGINT as completed_requests,
COALESCE(counts.failed, 0)::BIGINT as failed_requests,
COALESCE(counts.canceled, 0)::BIGINT as canceled_requests
FROM batches b
LEFT JOIN files f ON f.id = b.file_id
LEFT JOIN LATERAL (
SELECT
COUNT(*) FILTER (WHERE state = 'pending' AND b.cancelling_at IS NULL) as pending,
COUNT(*) FILTER (WHERE state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress,
COUNT(*) FILTER (WHERE state = 'completed') as completed,
COUNT(*) FILTER (WHERE state = 'failed') as failed,
COUNT(*) FILTER (WHERE state = 'canceled' OR (state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled
FROM requests
WHERE batch_id = b.id
) counts ON TRUE
WHERE b.file_id = "#,
    );
    qb.push_bind(*file_id as Uuid);
    qb.push(" AND b.deleted_at IS NULL ORDER BY b.created_at DESC");
    let fetched = qb
        .build()
        .fetch_all(self.pools.read())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to list batches: {}", e)))?;
    // Materialize the dynamic rows through the shared row-mapping macro.
    let mut statuses = Vec::with_capacity(fetched.len());
    for row in fetched {
        statuses.push(batch_status_from_dynamic_row!(row));
    }
    Ok(statuses)
}
/// Filters `batch_ids` down to those whose batch has a cancellation
/// requested (`cancelling_at` set) and is not soft-deleted.
/// An empty input short-circuits without touching the database.
async fn get_cancelled_batch_ids(&self, batch_ids: &[BatchId]) -> Result<Vec<BatchId>> {
    if batch_ids.is_empty() {
        return Ok(Vec::new());
    }
    // Unwrap the newtype so the slice can be bound as a uuid[] parameter.
    let raw_ids: Vec<Uuid> = batch_ids.iter().map(|id| **id).collect();
    let cancelled = sqlx::query_scalar!(
        r#"
SELECT id
FROM batches
WHERE id = ANY($1)
AND cancelling_at IS NOT NULL
AND deleted_at IS NULL
"#,
        &raw_ids,
    )
    .fetch_all(self.pools.read())
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!("Failed to fetch cancelled batch IDs: {}", e))
    })?;
    Ok(cancelled.into_iter().map(BatchId::from).collect())
}
/// Requests cancellation of a batch, stamping `cancelling_at` and
/// `cancelled_at` with the same instant in one idempotent UPDATE; a batch
/// that is already cancelling is left untouched by the WHERE guard.
#[tracing::instrument(skip(self), fields(batch_id = %batch_id))]
async fn cancel_batch(&self, batch_id: BatchId) -> Result<()> {
    let stamp = Utc::now();
    sqlx::query!(
        r#"
UPDATE batches
SET cancelling_at = $2,
cancelled_at = $2
WHERE id = $1 AND cancelling_at IS NULL
"#,
        *batch_id as Uuid,
        stamp,
    )
    .execute(self.pools.write())
    .await
    .map_err(|e| {
        FusilladeError::Other(anyhow!("Failed to set cancellation timestamps: {}", e))
    })
    .map(|_| ())
}
/// Soft-deletes a batch. A batch that is still running (no terminal
/// timestamp yet) also gets its cancellation timestamps stamped so
/// in-flight work stops; already-terminal batches keep their original
/// timestamps. Errors if the batch is missing or already deleted.
async fn delete_batch(&self, batch_id: BatchId) -> Result<()> {
    let updated = sqlx::query!(
        r#"
UPDATE batches
SET deleted_at = NOW(),
cancelling_at = CASE
WHEN completed_at IS NULL AND failed_at IS NULL AND cancelled_at IS NULL
THEN COALESCE(cancelling_at, NOW())
ELSE cancelling_at
END,
cancelled_at = CASE
WHEN completed_at IS NULL AND failed_at IS NULL AND cancelled_at IS NULL
THEN COALESCE(cancelled_at, NOW())
ELSE cancelled_at
END
WHERE id = $1
AND deleted_at IS NULL
"#,
        *batch_id as Uuid,
    )
    .execute(self.pools.write())
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to soft-delete batch: {}", e)))?
    .rows_affected();
    // Zero rows means the id didn't match a live (non-deleted) batch.
    match updated {
        0 => Err(FusilladeError::Other(anyhow!("Batch not found"))),
        _ => Ok(()),
    }
}
/// Re-queues a set of failed requests as fresh pending requests
/// (retry counter reset to 0, no scheduling delay).
///
/// Returns one `Result` per input id: `Ok(())` when the request was in the
/// failed state and has been re-persisted as pending, an `InvalidState`
/// error when it exists but is not failed, and the lookup error when it
/// could not be fetched. Ids dropped entirely by `get_requests` (e.g.
/// because their template was deleted) are reported as trailing
/// "not found" errors.
async fn retry_failed_requests(&self, ids: Vec<RequestId>) -> Result<Vec<Result<()>>> {
    tracing::debug!(count = ids.len(), "Retrying failed requests");
    let get_results = self.get_requests(ids.clone()).await?;
    let found_count = get_results.len();
    if found_count != ids.len() {
        // Diagnose which ids were dropped; get_requests omits requests
        // whose template rows no longer exist.
        let returned_ids: std::collections::HashSet<_> = get_results
            .iter()
            .filter_map(|r| r.as_ref().ok().map(|req| req.id()))
            .collect();
        let missing_ids: Vec<_> = ids.iter().filter(|id| !returned_ids.contains(id)).collect();
        tracing::warn!(
            missing_count = missing_ids.len(),
            "Some requests not found, likely due to deleted templates"
        );
    }
    // NOTE(review): results are paired with ids positionally; this assumes
    // get_requests preserves input order for the ids it does return — confirm.
    let mut results = Vec::with_capacity(ids.len());
    for (id, request_result) in ids.iter().zip(get_results.into_iter()) {
        let result = match request_result {
            Ok(AnyRequest::Failed(req)) => {
                // Rebuild as a brand-new pending request: retry bookkeeping
                // is cleared, but the batch expiry deadline is preserved.
                let pending_request = Request {
                    state: Pending {
                        retry_attempt: 0,
                        not_before: None,
                        batch_expires_at: req.state.batch_expires_at,
                    },
                    data: req.data,
                };
                self.persist(&pending_request).await?;
                Ok(())
            }
            Ok(_) => Err(crate::error::FusilladeError::InvalidState(
                *id,
                "non-failed state".to_string(),
                "failed state".to_string(),
            )),
            Err(e) => Err(e),
        };
        results.push(result);
    }
    // Pad with errors for ids get_requests dropped entirely. saturating_sub
    // guards the (unexpected) case of MORE results than ids, which with
    // plain subtraction would panic in debug or wrap into a huge loop
    // bound in release builds.
    for _ in 0..ids.len().saturating_sub(found_count) {
        results.push(Err(FusilladeError::Other(anyhow!(
            "Request not found - template may have been deleted"
        ))));
    }
    Ok(results)
}
/// Re-queues every failed request in `batch_id` inside one transaction
/// and, if any were reset, clears the batch's terminal timestamps so the
/// batch is treated as in-progress again (and becomes eligible for a new
/// completion notification). Returns the number of requests reset.
async fn retry_failed_requests_for_batch(&self, batch_id: BatchId) -> Result<u64> {
tracing::debug!(%batch_id, "Retrying all failed requests for batch");
let mut tx =
self.pools.write().begin().await.map_err(|e| {
FusilladeError::Other(anyhow!("Failed to begin transaction: {}", e))
})?;
// Reset failed requests back to a clean pending state: retry counter,
// backoff, error details and daemon ownership are all cleared.
let result = sqlx::query!(
r#"
UPDATE requests
SET state = 'pending',
retry_attempt = 0,
not_before = NULL,
error = NULL,
failed_at = NULL,
daemon_id = NULL,
claimed_at = NULL,
started_at = NULL
WHERE batch_id = $1 AND state = 'failed'
"#,
*batch_id as Uuid,
)
.execute(&mut *tx)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to retry failed requests: {}", e)))?;
let count = result.rows_affected();
// Only touch the batch row when something was actually re-queued;
// otherwise terminal timestamps stay as they are.
if count > 0 {
sqlx::query!(
r#"
UPDATE batches
SET completed_at = NULL,
failed_at = NULL,
finalizing_at = NULL,
notification_sent_at = NULL
WHERE id = $1
"#,
*batch_id as Uuid,
)
.execute(&mut *tx)
.await
.map_err(|e| {
FusilladeError::Other(anyhow!("Failed to reset batch terminal timestamps: {}", e))
})?;
}
tx.commit()
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to commit transaction: {}", e)))?;
tracing::debug!(%batch_id, count, "Retried failed requests for batch");
Ok(count)
}
/// Loads every request in a batch (skipping soft-deleted batches via the
/// WHERE clause) and rehydrates each row into its typed `AnyRequest`
/// state wrapper, in creation order.
///
/// Rows whose template has been deleted cannot be rebuilt into full
/// `RequestData` and are skipped with a debug log rather than erroring.
#[tracing::instrument(skip(self), fields(batch_id = %batch_id))]
async fn get_batch_requests(&self, batch_id: BatchId) -> Result<Vec<AnyRequest>> {
let rows = sqlx::query!(
r#"
SELECT
r.id, r.batch_id as "batch_id!", r.template_id as "template_id?", r.state,
t.custom_id as "custom_id?", t.endpoint as "endpoint?", t.method as "method?",
t.path as "path?", t.body as "body?", t.model as "model?", t.api_key as "api_key?",
r.retry_attempt, r.not_before, r.daemon_id, r.claimed_at, r.started_at,
r.response_status, r.response_body, r.completed_at, r.error, r.failed_at, r.canceled_at,
b.expires_at as batch_expires_at, r.routed_model
FROM requests r
LEFT JOIN active_request_templates t ON r.template_id = t.id
JOIN batches b ON r.batch_id = b.id
WHERE r.batch_id = $1 AND b.deleted_at IS NULL
ORDER BY r.created_at ASC
"#,
*batch_id as Uuid,
)
.fetch_all(self.pools.read())
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch batch executions: {}", e)))?;
let mut results = Vec::new();
for row in rows {
let request_id = RequestId(row.id);
// A row only yields usable RequestData when every template column is
// present; active_request_templates is LEFT JOINed, so a deleted
// template surfaces here as all-NULL template columns.
let data = match (
row.template_id,
row.endpoint,
row.method,
row.path,
row.body,
row.model,
row.api_key,
) {
(
Some(template_id),
Some(endpoint),
Some(method),
Some(path),
Some(body),
Some(model),
Some(api_key),
) => RequestData {
id: request_id,
batch_id: BatchId(row.batch_id),
template_id: TemplateId(template_id),
custom_id: row.custom_id,
endpoint,
method,
path,
body,
model,
api_key,
// NOTE(review): created_by and batch_metadata are filled with
// placeholders here rather than loaded from the batch row —
// confirm callers of this method don't rely on them.
created_by: String::new(),
batch_metadata: std::collections::HashMap::new(),
},
_ => {
tracing::debug!(request_id = %request_id, "Skipping batch request with deleted template");
continue;
}
};
// Map the textual state column onto the typed state machine; each arm
// requires the columns that state implies, erroring on inconsistency.
let state = &row.state;
let any_request = match state.as_str() {
"pending" => AnyRequest::Pending(Request {
state: Pending {
retry_attempt: row.retry_attempt as u32,
not_before: row.not_before,
batch_expires_at: row.batch_expires_at,
},
data,
}),
"claimed" => AnyRequest::Claimed(Request {
state: Claimed {
daemon_id: DaemonId(row.daemon_id.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing daemon_id for claimed execution"
))
})?),
claimed_at: row.claimed_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing claimed_at for claimed execution"
))
})?,
retry_attempt: row.retry_attempt as u32,
batch_expires_at: row.batch_expires_at,
},
data,
}),
"processing" => {
// Processing normally carries live plumbing (result channel and
// abort handle); for rows rehydrated from the database the
// original task is gone, so inert placeholders are used.
let (_tx, rx) = tokio::sync::mpsc::channel(1);
let abort_handle = tokio::spawn(async {}).abort_handle();
AnyRequest::Processing(Request {
state: Processing {
daemon_id: DaemonId(row.daemon_id.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing daemon_id for processing execution"
))
})?),
claimed_at: row.claimed_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing claimed_at for processing execution"
))
})?,
started_at: row.started_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing started_at for processing execution"
))
})?,
retry_attempt: row.retry_attempt as u32,
batch_expires_at: row.batch_expires_at,
result_rx: Arc::new(Mutex::new(rx)),
abort_handle,
},
data,
})
}
"completed" => AnyRequest::Completed(Request {
state: Completed {
response_status: row.response_status.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing response_status for completed execution"
))
})? as u16,
response_body: row.response_body.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing response_body for completed execution"
))
})?,
claimed_at: row.claimed_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing claimed_at for completed execution"
))
})?,
started_at: row.started_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing started_at for completed execution"
))
})?,
completed_at: row.completed_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing completed_at for completed execution"
))
})?,
// Presumably rows predating routed_model tracking have NULL
// here; fall back to the originally requested model.
routed_model: row.routed_model.unwrap_or_else(|| data.model.clone()),
},
data,
}),
"failed" => {
// The error column stores serialized FailureReason JSON; raw or
// legacy text that doesn't parse is wrapped as a network error.
let error_json = row.error.ok_or_else(|| {
FusilladeError::Other(anyhow!("Missing error for failed execution"))
})?;
let reason: FailureReason =
serde_json::from_str(&error_json).unwrap_or_else(|_| {
FailureReason::NetworkError {
error: error_json.clone(),
}
});
AnyRequest::Failed(Request {
state: Failed {
reason,
failed_at: row.failed_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing failed_at for failed execution"
))
})?,
retry_attempt: row.retry_attempt as u32,
batch_expires_at: row.batch_expires_at,
routed_model: row.routed_model.unwrap_or_else(|| data.model.clone()),
},
data,
})
}
"canceled" => AnyRequest::Canceled(Request {
state: Canceled {
canceled_at: row.canceled_at.ok_or_else(|| {
FusilladeError::Other(anyhow!(
"Missing canceled_at for canceled execution"
))
})?,
},
data,
}),
_ => {
return Err(FusilladeError::Other(anyhow!("Unknown state: {}", state)));
}
};
results.push(any_request);
}
Ok(results)
}
/// Returns a stream of per-request results for a batch, produced by a
/// background task so the caller can consume lazily.
#[tracing::instrument(skip(self), fields(batch_id = %batch_id, search = ?search, status = ?status))]
fn get_batch_results_stream(
    &self,
    batch_id: BatchId,
    offset: usize,
    search: Option<String>,
    status: Option<String>,
) -> Pin<Box<dyn Stream<Item = Result<crate::batch::BatchResultItem>> + Send>> {
    // Bridge a spawned producer to the caller through a bounded channel;
    // backpressure is governed by download_buffer_size.
    let (sender, receiver) = mpsc::channel(self.download_buffer_size);
    let pool = self.pools.read().clone();
    let start_offset = offset as i64;
    tokio::spawn(async move {
        Self::stream_batch_results(pool, batch_id, start_offset, search, status, sender).await;
    });
    Box::pin(ReceiverStream::new(receiver))
}
}
impl<P: PoolProvider, H: HttpClient + 'static> PostgresRequestManager<P, H> {
/// Fetches one non-deleted batch, with live per-state request counts
/// computed by a LATERAL subquery.
///
/// Also performs lazy finalization: when the counts show every request is
/// terminal but the batch row carries no terminal timestamp yet, the
/// finalizing/completed/failed timestamps are written back to the write
/// pool (a write on the read path) and the freshly stamped values are
/// returned instead of the stale ones.
async fn get_batch_from_pool(&self, batch_id: BatchId, pool: &PgPool) -> Result<Batch> {
let mut query_builder = QueryBuilder::new(
r#"
SELECT
b.id, b.file_id, b.endpoint, b.completion_window, b.metadata,
b.output_file_id, b.error_file_id, b.created_by, b.created_at,
b.expires_at, b.cancelling_at, b.errors,
b.total_requests,
b.requests_started_at,
b.finalizing_at,
b.completed_at,
b.failed_at,
b.cancelled_at,
b.deleted_at,
b.notification_sent_at,
b.api_key_id,
COALESCE(counts.pending, 0)::BIGINT as pending_requests,
COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests,
COALESCE(counts.completed, 0)::BIGINT as completed_requests,
COALESCE(counts.failed, 0)::BIGINT as failed_requests,
COALESCE(counts.canceled, 0)::BIGINT as canceled_requests
FROM batches b
LEFT JOIN LATERAL (
SELECT
COUNT(*) FILTER (WHERE state = 'pending' AND b.cancelling_at IS NULL) as pending,
COUNT(*) FILTER (WHERE state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress,
COUNT(*) FILTER (WHERE state = 'completed') as completed,
COUNT(*) FILTER (WHERE state = 'failed') as failed,
COUNT(*) FILTER (WHERE state = 'canceled' OR (state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled
FROM requests
WHERE batch_id = b.id
) counts ON TRUE
WHERE b.id = "#,
);
query_builder.push_bind(*batch_id as Uuid);
query_builder.push(" AND b.deleted_at IS NULL");
let row = query_builder
.build()
.fetch_optional(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch batch: {}", e)))?
.ok_or_else(|| FusilladeError::Other(anyhow!("Batch not found")))?;
let pending_requests: i64 = row.get("pending_requests");
let in_progress_requests: i64 = row.get("in_progress_requests");
let completed_requests: i64 = row.get("completed_requests");
let failed_requests: i64 = row.get("failed_requests");
let canceled_requests: i64 = row.get("canceled_requests");
let total_requests: i64 = row.get("total_requests");
let completed_at: Option<DateTime<Utc>> = row.get("completed_at");
let failed_at: Option<DateTime<Utc>> = row.get("failed_at");
let cancelled_at: Option<DateTime<Utc>> = row.get("cancelled_at");
let finalizing_at_db: Option<DateTime<Utc>> = row.get("finalizing_at");
// Terminal by count: every request reached completed/failed/canceled.
// Empty batches (total 0) are never considered terminal here.
let terminal_count = completed_requests + failed_requests + canceled_requests;
let is_terminal = terminal_count == total_requests && total_requests > 0;
let (finalizing_at, completed_at, failed_at) = if is_terminal
&& completed_at.is_none()
&& failed_at.is_none()
&& cancelled_at.is_none()
{
// At least one completed request makes the batch "completed";
// otherwise everything failed/canceled and it is "failed".
let now = Utc::now();
let (finalizing, completed, failed) = if completed_requests > 0 {
(Some(now), Some(now), None)
} else {
(Some(now), None, Some(now))
};
// COALESCE keeps any timestamp a concurrent writer set first;
// the returned values may then differ slightly from the stored ones.
sqlx::query!(
r#"
UPDATE batches
SET finalizing_at = COALESCE(finalizing_at, $2),
completed_at = COALESCE(completed_at, $3),
failed_at = COALESCE(failed_at, $4)
WHERE id = $1
"#,
*batch_id as Uuid,
finalizing,
completed,
failed,
)
.execute(self.pools.write()) .await
.map_err(|e| {
FusilladeError::Other(anyhow!("Failed to update terminal timestamps: {}", e))
})?;
(finalizing, completed, failed)
} else {
(finalizing_at_db, completed_at, failed_at)
};
Ok(Batch {
id: BatchId(row.get("id")),
file_id: row.get::<Option<Uuid>, _>("file_id").map(FileId),
created_at: row.get("created_at"),
metadata: row.get("metadata"),
completion_window: row.get("completion_window"),
endpoint: row.get("endpoint"),
output_file_id: row.get::<Option<Uuid>, _>("output_file_id").map(FileId),
error_file_id: row.get::<Option<Uuid>, _>("error_file_id").map(FileId),
created_by: row.get("created_by"),
expires_at: row.get("expires_at"),
cancelling_at: row.get("cancelling_at"),
errors: row.get("errors"),
total_requests,
pending_requests,
in_progress_requests,
completed_requests,
failed_requests,
canceled_requests,
requests_started_at: row.get("requests_started_at"),
finalizing_at,
completed_at,
failed_at,
cancelled_at,
deleted_at: row.get("deleted_at"),
notification_sent_at: row.get("notification_sent_at"),
api_key_id: row.get::<Option<Uuid>, _>("api_key_id"),
})
}
/// Scans for batches whose requests have all reached a terminal state but
/// which have not yet been notified, atomically claims them by stamping
/// `notification_sent_at` (the double `IS NULL` check guards against
/// concurrent pollers), and returns the data needed to build each
/// completion notification.
pub async fn poll_completed_batches(&self) -> Result<Vec<BatchNotification>> {
let rows = sqlx::query!(
r#"
-- Step 1: Find candidate batches that are terminal by count
WITH candidates AS (
SELECT b.id,
COALESCE(counts.completed, 0)::BIGINT as completed_requests,
COALESCE(counts.failed, 0)::BIGINT as failed_requests,
COALESCE(counts.canceled, 0)::BIGINT as canceled_requests,
COALESCE(counts.pending, 0)::BIGINT as pending_requests,
COALESCE(counts.in_progress, 0)::BIGINT as in_progress_requests
FROM batches b
-- Count requests by state for each batch
LEFT JOIN LATERAL (
SELECT
COUNT(*) FILTER (WHERE state = 'completed') as completed,
COUNT(*) FILTER (WHERE state = 'failed') as failed,
-- Canceled = explicitly canceled OR will be canceled (pending/in-progress with cancelling_at set)
COUNT(*) FILTER (WHERE state = 'canceled' OR (state IN ('pending', 'claimed', 'processing') AND b.cancelling_at IS NOT NULL)) as canceled,
COUNT(*) FILTER (WHERE state = 'pending' AND b.cancelling_at IS NULL) as pending,
COUNT(*) FILTER (WHERE state IN ('claimed', 'processing') AND b.cancelling_at IS NULL) as in_progress
FROM requests WHERE batch_id = b.id
) counts ON TRUE
WHERE b.notification_sent_at IS NULL -- Not yet notified
AND b.deleted_at IS NULL -- Not deleted
AND b.cancelling_at IS NULL -- Not canceled (don't email on user-canceled batches)
AND b.total_requests > 0 -- Has requests
AND (
-- Terminal by count: all requests reached terminal state
COALESCE(counts.completed, 0) + COALESCE(counts.failed, 0) + COALESCE(counts.canceled, 0) = b.total_requests
)
),
-- Step 2: Atomically claim batches and set terminal timestamps
updated AS (
UPDATE batches b
SET notification_sent_at = NOW(), -- Claim for notification (prevents duplicates)
-- Set terminal timestamps via COALESCE (no-op if already set by get_batch)
finalizing_at = COALESCE(b.finalizing_at, NOW()),
completed_at = COALESCE(b.completed_at,
CASE WHEN c.completed_requests > 0 THEN NOW() END),
failed_at = COALESCE(b.failed_at,
CASE WHEN c.completed_requests = 0 THEN NOW() END)
FROM candidates c
WHERE b.id = c.id
AND b.notification_sent_at IS NULL -- Re-check to handle concurrent pollers
RETURNING b.id, b.file_id, b.endpoint, b.completion_window, b.metadata,
b.output_file_id, b.error_file_id, b.created_by, b.created_at,
b.expires_at, b.cancelling_at, b.errors, b.total_requests,
b.requests_started_at, b.finalizing_at, b.completed_at,
b.failed_at, b.cancelled_at, b.deleted_at, b.notification_sent_at, b.api_key_id,
c.completed_requests, c.failed_requests, c.canceled_requests,
c.pending_requests, c.in_progress_requests
)
SELECT u.*,
f.name as input_file_name,
f.description as input_file_description,
(SELECT string_agg(DISTINCT r.model, ', ') FROM requests r WHERE r.batch_id = u.id) as model
FROM updated u
LEFT JOIN files f ON f.id = u.file_id
"#
)
.fetch_all(self.pools.write())
.await
.map_err(|e| {
FusilladeError::Other(anyhow!("Failed to poll completed batches: {}", e))
})?;
// Map each claimed row into a notification; the count columns come from
// the CTE and are nullable in sqlx's view, hence the unwrap_or(0)s.
Ok(rows
.into_iter()
.map(|row| BatchNotification {
batch: Batch {
id: BatchId(row.id),
file_id: row.file_id.map(FileId),
endpoint: row.endpoint,
completion_window: row.completion_window,
metadata: row.metadata,
output_file_id: row.output_file_id.map(FileId),
error_file_id: row.error_file_id.map(FileId),
created_by: row.created_by,
created_at: row.created_at,
expires_at: row.expires_at,
cancelling_at: row.cancelling_at,
errors: row.errors,
total_requests: row.total_requests,
requests_started_at: row.requests_started_at,
finalizing_at: row.finalizing_at,
completed_at: row.completed_at,
failed_at: row.failed_at,
cancelled_at: row.cancelled_at,
deleted_at: row.deleted_at,
notification_sent_at: row.notification_sent_at,
api_key_id: row.api_key_id,
pending_requests: row.pending_requests.unwrap_or(0),
in_progress_requests: row.in_progress_requests.unwrap_or(0),
completed_requests: row.completed_requests.unwrap_or(0),
failed_requests: row.failed_requests.unwrap_or(0),
canceled_requests: row.canceled_requests.unwrap_or(0),
},
model: row.model.unwrap_or_default(),
input_file_name: row.input_file_name,
input_file_description: row.input_file_description,
})
.collect())
}
/// Bulk-inserts one file's request templates via a single UNNEST insert.
/// `templates` pairs each parsed template with its source line number;
/// an empty slice is a no-op.
async fn insert_template_batch(
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    file_id: Uuid,
    templates: &[(RequestTemplateInput, i32)],
) -> Result<()> {
    if templates.is_empty() {
        return Ok(());
    }
    // Column-major buffers: Postgres UNNEST zips the parallel arrays back
    // into rows server-side, so one round-trip covers the whole batch.
    let n = templates.len();
    let mut custom_ids: Vec<Option<&str>> = Vec::with_capacity(n);
    let mut endpoints: Vec<&str> = Vec::with_capacity(n);
    let mut methods: Vec<&str> = Vec::with_capacity(n);
    let mut paths: Vec<&str> = Vec::with_capacity(n);
    let mut bodies: Vec<&str> = Vec::with_capacity(n);
    let mut models: Vec<&str> = Vec::with_capacity(n);
    let mut api_keys: Vec<&str> = Vec::with_capacity(n);
    let mut line_numbers: Vec<i32> = Vec::with_capacity(n);
    let mut body_byte_sizes: Vec<i64> = Vec::with_capacity(n);
    for (template, line) in templates {
        custom_ids.push(template.custom_id.as_deref());
        endpoints.push(template.endpoint.as_str());
        methods.push(template.method.as_str());
        paths.push(template.path.as_str());
        bodies.push(template.body.as_str());
        models.push(template.model.as_str());
        api_keys.push(template.api_key.as_str());
        line_numbers.push(*line);
        body_byte_sizes.push(template.body.len() as i64);
    }
    sqlx::query!(
        r#"
INSERT INTO request_templates (file_id, custom_id, endpoint, method, path, body, model, api_key, line_number, body_byte_size)
SELECT $1, custom_id, endpoint, method, path, body, model, api_key, line_number, body_byte_size
FROM UNNEST(
$2::text[], $3::text[], $4::text[], $5::text[], $6::text[],
$7::text[], $8::text[], $9::int[], $10::bigint[]
) AS t(custom_id, endpoint, method, path, body, model, api_key, line_number, body_byte_size)
"#,
        file_id,
        &custom_ids as &[Option<&str>],
        &endpoints as &[&str],
        &methods as &[&str],
        &paths as &[&str],
        &bodies as &[&str],
        &models as &[&str],
        &api_keys as &[&str],
        &line_numbers as &[i32],
        &body_byte_sizes as &[i64],
    )
    .execute(&mut **tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to batch insert templates: {}", e)))?;
    Ok(())
}
/// Streams a file's request templates to `tx` in line-number order.
///
/// Pages through the table BATCH_SIZE rows at a time using a keyset
/// cursor on `line_number`; the caller-supplied `offset` applies only to
/// the first page, after which the cursor takes over. Case-insensitive
/// custom_id search is applied when `search` is given. Stops silently
/// when the receiver is dropped; errors are sent down the channel.
async fn stream_request_templates(
pool: sqlx::PgPool,
file_id: FileId,
offset: i64,
search: Option<String>,
tx: mpsc::Sender<Result<FileContentItem>>,
) {
const BATCH_SIZE: i64 = 1000;
// -1 is the "no cursor yet" sentinel understood by the SQL below.
let mut last_line_number: i32 = -1;
let mut is_first_batch = true;
let search_pattern = search.map(|s| format!("%{}%", s.to_lowercase()));
loop {
let (line_filter, offset_val) = if is_first_batch {
(-1i32, offset)
} else {
(last_line_number, 0i64)
};
is_first_batch = false;
let template_batch = sqlx::query!(
r#"
SELECT custom_id, endpoint, method, path, body, model, api_key, line_number
FROM request_templates
WHERE file_id = $1 AND ($2 = -1 OR line_number > $2)
AND ($5::text IS NULL OR LOWER(custom_id) LIKE $5)
ORDER BY line_number ASC
OFFSET $3
LIMIT $4
"#,
*file_id as Uuid,
line_filter,
offset_val,
BATCH_SIZE,
search_pattern.as_deref(),
)
.fetch_all(&pool)
.await;
match template_batch {
Ok(templates) => {
// An empty page means the cursor has passed the last row.
if templates.is_empty() {
break;
}
tracing::debug!(
"Fetched batch of {} templates, line_numbers {}-{}",
templates.len(),
templates.first().map(|r| r.line_number).unwrap_or(0),
templates.last().map(|r| r.line_number).unwrap_or(0)
);
for row in templates {
// Advance the keyset cursor as rows are forwarded.
last_line_number = row.line_number;
let template = RequestTemplateInput {
custom_id: row.custom_id,
endpoint: row.endpoint,
method: row.method,
path: row.path,
body: row.body,
model: row.model,
api_key: row.api_key,
};
// send() fails only when the consumer hung up; just stop.
if tx
.send(Ok(FileContentItem::Template(template)))
.await
.is_err()
{
return;
}
}
}
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to fetch template batch: {}",
e
))))
.await;
return;
}
}
}
}
/// Streams a batch's completed-request outputs to `tx` as JSONL-style
/// `BatchOutputItem`s, resolving the owning batch from its output file id.
///
/// Pages with a composite keyset cursor on (completed_at, id) so rows
/// sharing a timestamp are not skipped or repeated; `offset` applies only
/// to the first page. Stops silently when the receiver is dropped.
async fn stream_batch_output(
pool: sqlx::PgPool,
file_id: FileId,
offset: i64,
search: Option<String>,
tx: mpsc::Sender<Result<FileContentItem>>,
) {
// The output file id is unique to one batch; look the batch up first.
let batch_result = sqlx::query!(
r#"
SELECT id
FROM batches
WHERE output_file_id = $1
"#,
*file_id as Uuid,
)
.fetch_one(&pool)
.await;
let batch_id = match batch_result {
Ok(row) => row.id,
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to find batch for output file: {}",
e
))))
.await;
return;
}
};
const BATCH_SIZE: i64 = 1000;
let mut last_completed_at: Option<chrono::DateTime<chrono::Utc>> = None;
let mut last_id: Uuid = Uuid::nil();
let mut is_first_batch = true;
let search_pattern = search.map(|s| format!("%{}%", s.to_lowercase()));
loop {
// First page: no cursor, honor caller offset. Later pages: cursor only.
let (cursor_time, cursor_id, offset_val) = if is_first_batch {
(None, Uuid::nil(), offset)
} else {
(last_completed_at, last_id, 0i64)
};
is_first_batch = false;
let request_batch = sqlx::query!(
r#"
SELECT id, custom_id, response_status, response_body, completed_at
FROM requests
WHERE batch_id = $1
AND state = 'completed'
AND ($2::TIMESTAMPTZ IS NULL OR completed_at > $2 OR (completed_at = $2 AND id > $3))
AND ($6::text IS NULL OR LOWER(custom_id) LIKE $6)
ORDER BY completed_at ASC, id ASC
OFFSET $4
LIMIT $5
"#,
batch_id,
cursor_time,
cursor_id,
offset_val,
BATCH_SIZE,
search_pattern.as_deref(),
)
.fetch_all(&pool)
.await;
match request_batch {
Ok(requests) => {
if requests.is_empty() {
break;
}
tracing::debug!("Fetched batch of {} completed requests", requests.len());
for row in requests {
last_completed_at = row.completed_at;
last_id = row.id;
// Bodies are stored as text; non-JSON bodies are passed
// through as a plain JSON string rather than dropped.
let response_body: serde_json::Value = match &row.response_body {
Some(body) => match serde_json::from_str(body) {
Ok(json) => json,
Err(e) => {
tracing::warn!("Failed to parse response body as JSON: {}", e);
serde_json::Value::String(body.to_string())
}
},
None => serde_json::Value::Null,
};
// NOTE(review): a completed row with NULL response_status is
// reported as 200 here — confirm that default is intended.
let output_item = BatchOutputItem {
id: format!("batch_req_{}", row.id),
custom_id: row.custom_id,
response: BatchResponseDetails {
status_code: row.response_status.unwrap_or(200),
request_id: None, body: response_body,
},
error: None,
};
if tx
.send(Ok(FileContentItem::Output(output_item)))
.await
.is_err()
{
return;
}
}
}
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to fetch completed requests: {}",
e
))))
.await;
return;
}
}
}
}
/// Streams a batch's failed-request records to `tx` as `BatchErrorItem`s,
/// resolving the owning batch from its error file id.
///
/// Mirrors `stream_batch_output` but pages on (failed_at, id) and builds
/// the query with QueryBuilder instead of the `query!` macro. `offset`
/// applies only to the first page; stops silently when the receiver drops.
async fn stream_batch_error(
pool: sqlx::PgPool,
file_id: FileId,
offset: i64,
search: Option<String>,
tx: mpsc::Sender<Result<FileContentItem>>,
) {
let batch_result = sqlx::query!(
r#"
SELECT id, expires_at
FROM batches
WHERE error_file_id = $1
"#,
*file_id as Uuid,
)
.fetch_one(&pool)
.await;
let (batch_id, _expires_at) = match batch_result {
Ok(row) => (row.id, row.expires_at),
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to find batch for error file: {}",
e
))))
.await;
return;
}
};
const BATCH_SIZE: i64 = 1000;
let mut last_failed_at: Option<chrono::DateTime<chrono::Utc>> = None;
let mut last_id: Uuid = Uuid::nil();
let mut is_first_batch = true;
let search_pattern = search.map(|s| format!("%{}%", s.to_lowercase()));
loop {
let (cursor_time, cursor_id, offset_val) = if is_first_batch {
(None, Uuid::nil(), offset)
} else {
(last_failed_at, last_id, 0i64)
};
is_first_batch = false;
// Dynamic query equivalent to:
//   WHERE batch_id = ? AND state = 'failed'
//     AND (?::TIMESTAMPTZ IS NULL OR failed_at > ? OR (failed_at = ? AND id > ?))
//     AND (?::text IS NULL OR LOWER(custom_id) LIKE ?)
//   ORDER BY failed_at ASC, id ASC OFFSET ? LIMIT ?
let mut query_builder = QueryBuilder::new(
r#"
SELECT id, custom_id, error, failed_at
FROM requests
WHERE batch_id = "#,
);
query_builder.push_bind(batch_id);
query_builder.push(" AND state = 'failed' AND (");
query_builder.push_bind(cursor_time);
query_builder.push("::TIMESTAMPTZ IS NULL OR failed_at > ");
query_builder.push_bind(cursor_time);
query_builder.push(" OR (failed_at = ");
query_builder.push_bind(cursor_time);
query_builder.push(" AND id > ");
query_builder.push_bind(cursor_id);
query_builder.push(")) AND (");
query_builder.push_bind(search_pattern.as_deref());
query_builder.push("::text IS NULL OR LOWER(custom_id) LIKE ");
query_builder.push_bind(search_pattern.as_deref());
query_builder.push(")");
query_builder.push(" ORDER BY failed_at ASC, id ASC OFFSET ");
query_builder.push_bind(offset_val);
query_builder.push(" LIMIT ");
query_builder.push_bind(BATCH_SIZE);
let request_batch = query_builder.build().fetch_all(&pool).await;
match request_batch {
Ok(requests) => {
if requests.is_empty() {
break;
}
tracing::debug!("Fetched batch of {} failed requests", requests.len());
for row in requests {
let id: Uuid = row.get("id");
let custom_id: Option<String> = row.get("custom_id");
let error: Option<String> = row.get("error");
let failed_at: Option<DateTime<Utc>> = row.get("failed_at");
// NOTE(review): the cursor assumes failed_at is non-NULL on
// 'failed' rows; a NULL would reset the cursor to the start
// and could loop — confirm the schema guarantees it.
last_failed_at = failed_at;
last_id = id;
let error_item = BatchErrorItem {
id: format!("batch_req_{}", id),
custom_id,
response: None,
error: BatchErrorDetails {
code: None, message: error.unwrap_or_else(|| "Unknown error".to_string()),
},
};
if tx
.send(Ok(FileContentItem::Error(error_item)))
.await
.is_err()
{
return;
}
}
}
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to fetch failed requests: {}",
e
))))
.await;
return;
}
}
}
}
/// Streams one `BatchResultItem` per request in a batch, joined against
/// the originating templates and ordered by input-file line number.
///
/// Supports optional case-insensitive custom_id search and a status
/// filter ("in_progress" expands to claimed+processing). Pages with a
/// keyset cursor on line_number; `offset` applies only to the first page.
/// Errors are sent down the channel; a dropped receiver ends the stream.
async fn stream_batch_results(
pool: sqlx::PgPool,
batch_id: BatchId,
offset: i64,
search: Option<String>,
status: Option<String>,
tx: mpsc::Sender<Result<crate::batch::BatchResultItem>>,
) {
use crate::batch::{BatchResultItem, BatchResultStatus};
// Resolve the batch's input file; results are keyed off its templates.
let (file_id, _expires_at) = match sqlx::query!(
r#"SELECT file_id, expires_at FROM batches WHERE id = $1 AND deleted_at IS NULL"#,
*batch_id as Uuid,
)
.fetch_optional(&pool)
.await
{
Ok(Some(row)) => {
if let Some(fid) = row.file_id {
(fid, row.expires_at)
} else {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Batch has no associated file_id"
))))
.await;
return;
}
}
Ok(None) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!("Batch not found"))))
.await;
return;
}
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to fetch batch: {}",
e
))))
.await;
return;
}
};
const BATCH_SIZE: i64 = 1000;
// -1 sentinel = no keyset cursor yet (first page).
let mut last_line_number: i32 = -1;
let mut is_first_batch = true;
let search_pattern = search.map(|s| format!("%{}%", s.to_lowercase()));
// "in_progress" is a UI-level status covering two DB states.
let state_filter: Option<Vec<String>> = status.map(|s| match s.as_str() {
"in_progress" => vec!["claimed".to_string(), "processing".to_string()],
other => vec![other.to_string()],
});
loop {
let (line_filter, offset_val) = if is_first_batch {
(-1i32, offset)
} else {
(last_line_number, 0i64)
};
is_first_batch = false;
let mut query_builder = QueryBuilder::new(
r#"
SELECT
r.id,
r.custom_id,
r.model,
r.state,
t.body as input_body,
r.response_body,
r.error,
t.line_number
FROM request_templates t
JOIN requests r ON r.template_id = t.id AND r.batch_id = "#,
);
query_builder.push_bind(*batch_id as Uuid);
query_builder.push(" WHERE t.file_id = ");
query_builder.push_bind(file_id);
query_builder.push(" AND (");
query_builder.push_bind(line_filter);
query_builder.push(" = -1 OR t.line_number > ");
query_builder.push_bind(line_filter);
query_builder.push(") AND (");
query_builder.push_bind(search_pattern.as_deref());
query_builder.push("::text IS NULL OR LOWER(r.custom_id) LIKE ");
query_builder.push_bind(search_pattern.as_deref());
query_builder.push(") AND (");
query_builder.push_bind(state_filter.as_deref());
query_builder.push("::text[] IS NULL OR r.state = ANY(");
query_builder.push_bind(state_filter.as_deref());
query_builder.push("))");
query_builder.push(" ORDER BY t.line_number ASC OFFSET ");
query_builder.push_bind(offset_val);
query_builder.push(" LIMIT ");
query_builder.push_bind(BATCH_SIZE);
let request_batch = query_builder.build().fetch_all(&pool).await;
match request_batch {
Ok(requests) => {
if requests.is_empty() {
break;
}
tracing::debug!("Fetched batch of {} results", requests.len());
for row in requests {
let line_number: i32 = row.get("line_number");
last_line_number = line_number;
let input_body_str: String = row.get("input_body");
let response_body_opt: Option<String> = row.get("response_body");
let state: String = row.get("state");
let id: Uuid = row.get("id");
let custom_id: Option<String> = row.get("custom_id");
let model: String = row.get("model");
let error: Option<String> = row.get("error");
// Non-JSON bodies are forwarded as plain JSON strings.
let input_body: serde_json::Value = serde_json::from_str(&input_body_str)
.unwrap_or_else(|_| serde_json::Value::String(input_body_str.clone()));
let response_body: Option<serde_json::Value> =
response_body_opt.as_ref().map(|body| {
serde_json::from_str(body)
.unwrap_or_else(|_| serde_json::Value::String(body.to_string()))
});
// Unknown/unmapped DB states fall back to Pending.
let status = match state.as_str() {
"completed" => BatchResultStatus::Completed,
"failed" => BatchResultStatus::Failed,
"pending" => BatchResultStatus::Pending,
"claimed" | "processing" => BatchResultStatus::InProgress,
_ => BatchResultStatus::Pending, };
let result_item = BatchResultItem {
id: id.to_string(),
custom_id,
model,
input_body,
response_body,
error,
status,
};
if tx.send(Ok(result_item)).await.is_err() {
return;
}
}
}
Err(e) => {
let _ = tx
.send(Err(FusilladeError::Other(anyhow!(
"Failed to fetch batch results: {}",
e
))))
.await;
return;
}
}
}
}
/// Insert a placeholder `files` row that will hold this batch's output JSONL.
///
/// The row starts with `size_bytes = 0` and `size_finalized = FALSE`; the real
/// size is finalized later as results accumulate. An empty `created_by` is
/// stored as SQL NULL via `NULLIF`. Runs inside the caller's transaction.
async fn create_virtual_output_file(
    &self,
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    batch_id: Uuid,
    created_by: &str,
) -> Result<Uuid> {
    let file_name = format!("batch-{}-output.jsonl", batch_id);
    let file_description = format!("Output file for batch {}", batch_id);
    sqlx::query_scalar!(
        r#"
        INSERT INTO files (name, description, size_bytes, size_finalized, status, purpose, uploaded_by)
        VALUES ($1, $2, 0, FALSE, 'processed', 'batch_output', NULLIF($3, ''))
        RETURNING id
        "#,
        file_name,
        file_description,
        created_by,
    )
    .fetch_one(&mut **tx)
    .await
    .map_err(|e| FusilladeError::Other(anyhow!("Failed to create output file: {}", e)))
}
async fn create_virtual_error_file(
&self,
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
batch_id: Uuid,
created_by: &str,
) -> Result<Uuid> {
let name = format!("batch-{}-error.jsonl", batch_id);
let description = format!("Error file for batch {}", batch_id);
let file_id = sqlx::query_scalar!(
r#"
INSERT INTO files (name, description, size_bytes, size_finalized, status, purpose, uploaded_by)
VALUES ($1, $2, 0, FALSE, 'processed', 'batch_error', NULLIF($3, ''))
RETURNING id
"#,
name,
description,
created_by,
)
.fetch_one(&mut **tx)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to create error file: {}", e)))?;
Ok(file_id)
}
}
#[async_trait]
impl<P: PoolProvider, H: HttpClient> DaemonStorage for PostgresRequestManager<P, H> {
    /// Upsert a daemon record in its current typed lifecycle state.
    ///
    /// The typed record is converted to `AnyDaemonRecord`, then one
    /// state-specific `INSERT ... ON CONFLICT (id) DO UPDATE` is issued so a
    /// single statement handles both first registration and later state
    /// transitions for the same daemon id.
    async fn persist_daemon<T: DaemonState + Clone>(&self, record: &DaemonRecord<T>) -> Result<()>
    where
        AnyDaemonRecord: From<DaemonRecord<T>>,
    {
        let any_daemon = AnyDaemonRecord::from(record.clone());
        match any_daemon {
            AnyDaemonRecord::Initializing(daemon) => {
                // Fresh registration: no heartbeat or stop time yet, zeroed counters.
                sqlx::query!(
                    r#"
                    INSERT INTO daemons (
                        id, status, hostname, pid, version, config_snapshot,
                        started_at, last_heartbeat, stopped_at,
                        requests_processed, requests_failed, requests_in_flight
                    ) VALUES ($1, 'initializing', $2, $3, $4, $5, $6, NULL, NULL, 0, 0, 0)
                    ON CONFLICT (id) DO UPDATE SET
                        status = 'initializing',
                        started_at = $6,
                        updated_at = NOW()
                    "#,
                    *daemon.data.id as Uuid,
                    daemon.data.hostname,
                    daemon.data.pid,
                    daemon.data.version,
                    daemon.data.config_snapshot,
                    daemon.state.started_at,
                )
                .execute(self.pools.write())
                .await
                .map_err(|e| FusilladeError::Other(anyhow!("Failed to persist daemon: {}", e)))?;
            }
            AnyDaemonRecord::Running(daemon) => {
                // Heartbeat update: refresh liveness timestamp and live stats.
                // Counters are cast to the column types (u64 -> i64, usize -> i32).
                sqlx::query!(
                    r#"
                    INSERT INTO daemons (
                        id, status, hostname, pid, version, config_snapshot,
                        started_at, last_heartbeat, stopped_at,
                        requests_processed, requests_failed, requests_in_flight
                    ) VALUES ($1, 'running', $2, $3, $4, $5, $6, $7, NULL, $8, $9, $10)
                    ON CONFLICT (id) DO UPDATE SET
                        status = 'running',
                        last_heartbeat = $7,
                        requests_processed = $8,
                        requests_failed = $9,
                        requests_in_flight = $10,
                        updated_at = NOW()
                    "#,
                    *daemon.data.id as Uuid,
                    daemon.data.hostname,
                    daemon.data.pid,
                    daemon.data.version,
                    daemon.data.config_snapshot,
                    daemon.state.started_at,
                    daemon.state.last_heartbeat,
                    daemon.state.stats.requests_processed as i64,
                    daemon.state.stats.requests_failed as i64,
                    daemon.state.stats.requests_in_flight as i32,
                )
                .execute(self.pools.write())
                .await
                .map_err(|e| FusilladeError::Other(anyhow!("Failed to persist daemon: {}", e)))?;
            }
            AnyDaemonRecord::Dead(daemon) => {
                // Terminal state: record stop time and the final stat snapshot.
                sqlx::query!(
                    r#"
                    INSERT INTO daemons (
                        id, status, hostname, pid, version, config_snapshot,
                        started_at, last_heartbeat, stopped_at,
                        requests_processed, requests_failed, requests_in_flight
                    ) VALUES ($1, 'dead', $2, $3, $4, $5, $6, NULL, $7, $8, $9, $10)
                    ON CONFLICT (id) DO UPDATE SET
                        status = 'dead',
                        stopped_at = $7,
                        requests_processed = $8,
                        requests_failed = $9,
                        requests_in_flight = $10,
                        updated_at = NOW()
                    "#,
                    *daemon.data.id as Uuid,
                    daemon.data.hostname,
                    daemon.data.pid,
                    daemon.data.version,
                    daemon.data.config_snapshot,
                    daemon.state.started_at,
                    daemon.state.stopped_at,
                    daemon.state.final_stats.requests_processed as i64,
                    daemon.state.final_stats.requests_failed as i64,
                    daemon.state.final_stats.requests_in_flight as i32,
                )
                .execute(self.pools.write())
                .await
                .map_err(|e| FusilladeError::Other(anyhow!("Failed to persist daemon: {}", e)))?;
            }
        }
        Ok(())
    }
    /// Fetch a single daemon row by id and rebuild the typed record from its
    /// `status` discriminator.
    ///
    /// Missing state-specific columns (`last_heartbeat` for a running daemon,
    /// `stopped_at` for a dead one) and unknown status strings are reported as
    /// errors rather than silently coerced.
    async fn get_daemon(&self, daemon_id: DaemonId) -> Result<AnyDaemonRecord> {
        let row = sqlx::query!(
            r#"
            SELECT
                id, status, hostname, pid, version, config_snapshot,
                started_at, last_heartbeat, stopped_at,
                requests_processed, requests_failed, requests_in_flight
            FROM daemons
            WHERE id = $1
            "#,
            *daemon_id as Uuid,
        )
        .fetch_one(self.pools.read())
        .await
        .map_err(|e| match e {
            // fetch_one surfaces a missing row as RowNotFound; give it a
            // clearer message than the generic fetch failure.
            sqlx::Error::RowNotFound => FusilladeError::Other(anyhow!("Daemon not found")),
            _ => FusilladeError::Other(anyhow!("Failed to fetch daemon: {}", e)),
        })?;
        let data = DaemonData {
            id: DaemonId(row.id),
            hostname: row.hostname,
            pid: row.pid,
            version: row.version,
            config_snapshot: row.config_snapshot,
        };
        let any_daemon = match row.status.as_str() {
            "initializing" => AnyDaemonRecord::Initializing(DaemonRecord {
                data,
                state: Initializing {
                    started_at: row.started_at,
                },
            }),
            "running" => AnyDaemonRecord::Running(DaemonRecord {
                data,
                state: Running {
                    started_at: row.started_at,
                    last_heartbeat: row.last_heartbeat.ok_or_else(|| {
                        FusilladeError::Other(anyhow!("Running daemon missing last_heartbeat"))
                    })?,
                    stats: crate::daemon::DaemonStats {
                        requests_processed: row.requests_processed as u64,
                        requests_failed: row.requests_failed as u64,
                        requests_in_flight: row.requests_in_flight as usize,
                    },
                },
            }),
            "dead" => AnyDaemonRecord::Dead(DaemonRecord {
                data,
                state: Dead {
                    started_at: row.started_at,
                    stopped_at: row.stopped_at.ok_or_else(|| {
                        FusilladeError::Other(anyhow!("Dead daemon missing stopped_at"))
                    })?,
                    final_stats: crate::daemon::DaemonStats {
                        requests_processed: row.requests_processed as u64,
                        requests_failed: row.requests_failed as u64,
                        requests_in_flight: row.requests_in_flight as usize,
                    },
                },
            }),
            _ => {
                return Err(FusilladeError::Other(anyhow!(
                    "Unknown daemon status: {}",
                    row.status
                )));
            }
        };
        Ok(any_daemon)
    }
    /// List daemons, optionally filtered by status, newest first.
    ///
    /// Unlike `get_daemon`, rows with inconsistent data (a running daemon
    /// without a heartbeat, a dead daemon without a stop time, or an unknown
    /// status string) are silently skipped so one bad row cannot fail the
    /// whole listing.
    async fn list_daemons(
        &self,
        status_filter: Option<DaemonStatus>,
    ) -> Result<Vec<AnyDaemonRecord>> {
        let status_str = status_filter.as_ref().map(|s| s.as_str());
        let rows = sqlx::query!(
            r#"
            SELECT
                id, status, hostname, pid, version, config_snapshot,
                started_at, last_heartbeat, stopped_at,
                requests_processed, requests_failed, requests_in_flight
            FROM daemons
            WHERE ($1::text IS NULL OR status = $1)
            ORDER BY created_at DESC
            "#,
            status_str,
        )
        .fetch_all(self.pools.read())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to list daemons: {}", e)))?;
        let mut daemons = Vec::new();
        for row in rows {
            let data = DaemonData {
                id: DaemonId(row.id),
                hostname: row.hostname,
                pid: row.pid,
                version: row.version,
                config_snapshot: row.config_snapshot,
            };
            let any_daemon = match row.status.as_str() {
                "initializing" => AnyDaemonRecord::Initializing(DaemonRecord {
                    data,
                    state: Initializing {
                        started_at: row.started_at,
                    },
                }),
                "running" => {
                    if let Some(last_heartbeat) = row.last_heartbeat {
                        AnyDaemonRecord::Running(DaemonRecord {
                            data,
                            state: Running {
                                started_at: row.started_at,
                                last_heartbeat,
                                stats: crate::daemon::DaemonStats {
                                    requests_processed: row.requests_processed as u64,
                                    requests_failed: row.requests_failed as u64,
                                    requests_in_flight: row.requests_in_flight as usize,
                                },
                            },
                        })
                    } else {
                        // Inconsistent row: running without a heartbeat — skip it.
                        continue;
                    }
                }
                "dead" => {
                    if let Some(stopped_at) = row.stopped_at {
                        AnyDaemonRecord::Dead(DaemonRecord {
                            data,
                            state: Dead {
                                started_at: row.started_at,
                                stopped_at,
                                final_stats: crate::daemon::DaemonStats {
                                    requests_processed: row.requests_processed as u64,
                                    requests_failed: row.requests_failed as u64,
                                    requests_in_flight: row.requests_in_flight as usize,
                                },
                            },
                        })
                    } else {
                        // Inconsistent row: dead without a stop time — skip it.
                        continue;
                    }
                }
                _ => {
                    // Unknown status string — skip rather than fail the listing.
                    continue;
                }
            };
            daemons.push(any_daemon);
        }
        Ok(daemons)
    }
    /// Garbage-collect rows orphaned by soft deletes.
    ///
    /// Deletes up to `batch_size` `requests` rows belonging to soft-deleted
    /// batches and up to `batch_size` `request_templates` rows belonging to
    /// soft-deleted files, returning the total rows removed. The LATERAL
    /// subquery with `FOR UPDATE SKIP LOCKED` lets concurrent purgers work
    /// without blocking on each other's locked rows; callers are expected to
    /// invoke this repeatedly until it returns 0.
    async fn purge_orphaned_rows(&self, batch_size: i64) -> Result<u64> {
        let requests_deleted = sqlx::query!(
            r#"
            DELETE FROM requests
            WHERE id IN (
                SELECT r.id
                FROM (SELECT id FROM batches WHERE deleted_at IS NOT NULL) b,
                LATERAL (
                    SELECT id FROM requests
                    WHERE batch_id = b.id
                    LIMIT $1
                    FOR UPDATE SKIP LOCKED
                ) r
                LIMIT $1
            )
            "#,
            batch_size,
        )
        .execute(self.pools.write())
        .await
        .map_err(|e| FusilladeError::Other(anyhow!("Failed to purge orphaned requests: {}", e)))?
        .rows_affected() as i64;
        let templates_deleted = sqlx::query!(
            r#"
            DELETE FROM request_templates
            WHERE id IN (
                SELECT rt.id
                FROM (SELECT id FROM files WHERE deleted_at IS NOT NULL) f,
                LATERAL (
                    SELECT id FROM request_templates
                    WHERE file_id = f.id
                    LIMIT $1
                    FOR UPDATE SKIP LOCKED
                ) rt
                LIMIT $1
            )
            "#,
            batch_size,
        )
        .execute(self.pools.write())
        .await
        .map_err(|e| {
            FusilladeError::Other(anyhow!("Failed to purge orphaned request_templates: {}", e))
        })?
        .rows_affected() as i64;
        let total = (requests_deleted + templates_deleted) as u64;
        if total > 0 {
            // Only log when something was actually removed to keep idle runs quiet.
            tracing::info!(requests_deleted, templates_deleted, "Purged orphaned rows");
        }
        Ok(total)
    }
}
#[async_trait]
impl<P: PoolProvider, H: HttpClient + 'static> DaemonExecutor<H> for PostgresRequestManager<P, H> {
    /// Shared HTTP client the daemon uses to execute requests.
    fn http_client(&self) -> &Arc<H> {
        &self.http_client
    }

    /// Daemon configuration this manager was constructed with.
    fn config(&self) -> &DaemonConfig {
        &self.config
    }

    /// Spawn the request-processing daemon onto the tokio runtime.
    ///
    /// Constructs a `Daemon` backed by this manager and returns the join
    /// handle; the handle resolves with the daemon's final result after the
    /// cancellation token fires and shutdown completes.
    fn run(
        self: Arc<Self>,
        shutdown_token: tokio_util::sync::CancellationToken,
    ) -> Result<JoinHandle<Result<()>>> {
        tracing::info!("Starting PostgreSQL request manager daemon");
        let daemon = Arc::new(Daemon::new(
            self.clone(),
            self.http_client.clone(),
            self.config.clone(),
            shutdown_token,
        ));
        let join_handle = tokio::spawn(async move { daemon.run().await });
        tracing::info!("Daemon spawned successfully");
        Ok(join_handle)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestDbPools;
use crate::batch::FileStreamResult;
use crate::daemon::{
AnyDaemonRecord, DaemonData, DaemonRecord, DaemonStats, DaemonStatus, Dead, Initializing,
Running,
};
use crate::http::MockHttpClient;
use chrono::Timelike;
/// Unwrap a `FileStreamResult`, panicking with a test-friendly message if the
/// stream was aborted instead of producing a file.
fn expect_stream_success(result: FileStreamResult) -> FileId {
    match result {
        FileStreamResult::Aborted => panic!("Expected stream creation success, got abort"),
        FileStreamResult::Success(file_id) => file_id,
    }
}
/// End-to-end smoke test: create a file with two request templates, then read
/// back both the file metadata and its content in insertion order.
#[sqlx::test]
async fn test_create_and_get_file(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "test-file".to_string(),
            Some("A test file".to_string()),
            vec![
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/completions".to_string(),
                    body: r#"{"model":"gpt-4"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "key1".to_string(),
                },
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/completions".to_string(),
                    body: r#"{"model":"gpt-3.5"}"#.to_string(),
                    model: "gpt-3.5".to_string(),
                    api_key: "key2".to_string(),
                },
            ],
        )
        .await
        .expect("Failed to create file");
    let file = manager.get_file(file_id).await.expect("Failed to get file");
    assert_eq!(file.id, file_id);
    assert_eq!(file.name, "test-file");
    assert_eq!(file.description, Some("A test file".to_string()));
    let content = manager
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(content.len(), 2);
    // Content must come back in the same order the templates were supplied.
    match &content[0] {
        FileContentItem::Template(t) => assert_eq!(t.model, "gpt-4"),
        _ => panic!("Expected template"),
    }
    match &content[1] {
        FileContentItem::Template(t) => assert_eq!(t.model, "gpt-3.5"),
        _ => panic!("Expected template"),
    }
}
/// Batched insert with a file smaller than the default batch size (10 << 5000):
/// all templates must land in one batch with order and content preserved.
#[sqlx::test]
async fn test_batched_insert_small_file(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let templates: Vec<RequestTemplateInput> = (0..10)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("batch-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: format!(r#"{{"prompt":"test {}"}}"#, i),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file(
            "batched-small".to_string(),
            Some("Small file for batched insert test".to_string()),
            templates,
        )
        .await
        .expect("Failed to create file");
    let content = manager
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(content.len(), 10, "Should have 10 templates");
    // Every template must come back in order with its original custom_id/body.
    for (i, item) in content.iter().enumerate() {
        match item {
            FileContentItem::Template(t) => {
                assert_eq!(t.custom_id, Some(format!("batch-{}", i)));
                assert_eq!(t.body, format!(r#"{{"prompt":"test {}"}}"#, i));
            }
            _ => panic!("Expected template"),
        }
    }
}
/// Batched insert with 15,000 templates — enough to span multiple batches at
/// the default batch size of 5000. Spot-checks the first, middle, and last
/// elements so each underlying batch is exercised.
#[sqlx::test]
async fn test_batched_insert_large_file(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let template_count = 15_000;
    let templates: Vec<RequestTemplateInput> = (0..template_count)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("large-batch-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            // Pad the body so rows are non-trivially sized.
            body: format!(r#"{{"prompt":"test {}","data":{}}}"#, i, "x".repeat(100)),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file(
            "batched-large".to_string(),
            Some("Large file for batched insert test".to_string()),
            templates,
        )
        .await
        .expect("Failed to create file");
    let content = manager
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(
        content.len(),
        template_count,
        "Should have {} templates",
        template_count
    );
    // Spot-check first, middle, and last items (covering separate batches).
    match &content[0] {
        FileContentItem::Template(t) => {
            assert_eq!(t.custom_id, Some("large-batch-0".to_string()));
        }
        _ => panic!("Expected template"),
    }
    match &content[7500] {
        FileContentItem::Template(t) => {
            assert_eq!(t.custom_id, Some("large-batch-7500".to_string()));
        }
        _ => panic!("Expected template"),
    }
    match &content[14999] {
        FileContentItem::Template(t) => {
            assert_eq!(t.custom_id, Some("large-batch-14999".to_string()));
        }
        _ => panic!("Expected template"),
    }
}
/// With a small batch size (50) and 150 templates (3 batches), line numbers in
/// `request_templates` must remain globally sequential across batch
/// boundaries — verified directly against the table, bypassing the manager.
#[sqlx::test]
async fn test_batched_insert_preserves_line_numbers(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    )
    .with_batch_insert_strategy(BatchInsertStrategy::Batched { batch_size: 50 });
    let template_count = 150;
    let templates: Vec<RequestTemplateInput> = (0..template_count)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("line-{}", i)),
            endpoint: "https://api.openai.com/v1".to_string(),
            method: "POST".to_string(),
            path: "/chat/completions".to_string(),
            body: format!(
                r#"{{"model":"gpt-4","messages":[{{"role":"user","content":"line {}"}}]}}"#,
                i
            ),
            model: "gpt-4".to_string(),
            api_key: "test-key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file("test-batched-lines".to_string(), None, templates)
        .await
        .expect("Failed to create file");
    // Query the table directly to check the stored line numbers.
    let rows = sqlx::query!(
        r#"
        SELECT custom_id, line_number
        FROM request_templates
        WHERE file_id = $1
        ORDER BY line_number
        "#,
        *file_id as Uuid,
    )
    .fetch_all(&pool)
    .await
    .expect("Failed to query templates");
    assert_eq!(rows.len(), template_count);
    for (i, row) in rows.iter().enumerate() {
        assert_eq!(
            row.line_number, i as i32,
            "Line number {} should be sequential",
            i
        );
        assert_eq!(row.custom_id.as_ref().unwrap(), &format!("line-{}", i));
    }
}
/// Streaming ingestion path: metadata item first, then 8000 templates.
/// Verifies the filename from metadata is applied and checks the elements on
/// either side of the default 5000-item batch boundary (4999/5000).
#[sqlx::test]
async fn test_batched_insert_with_stream(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // The stream protocol: one Metadata item, then Template items.
    let mut items = vec![FileStreamItem::Metadata(FileMetadata {
        filename: Some("streamed-batched".to_string()),
        description: Some("Batched insert via stream".to_string()),
        purpose: None,
        expires_after_anchor: None,
        expires_after_seconds: None,
        size_bytes: None,
        uploaded_by: Some("test-user".to_string()),
        api_key_id: None,
    })];
    for i in 0..8000 {
        items.push(FileStreamItem::Template(RequestTemplateInput {
            custom_id: Some(format!("stream-batch-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: format!(r#"{{"n":{}}}"#, i),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        }));
    }
    let stream = stream::iter(items);
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream)
            .await
            .expect("Failed to create file from stream"),
    );
    let file = manager.get_file(file_id).await.expect("Failed to get file");
    assert_eq!(file.name, "streamed-batched");
    let content = manager
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(content.len(), 8000);
    // Check items straddling the default 5000-item batch boundary.
    match &content[4999] {
        FileContentItem::Template(t) => {
            assert_eq!(t.custom_id, Some("stream-batch-4999".to_string()));
        }
        _ => panic!("Expected template"),
    }
    match &content[5000] {
        FileContentItem::Template(t) => {
            assert_eq!(t.custom_id, Some("stream-batch-5000".to_string()));
        }
        _ => panic!("Expected template"),
    }
}
/// A stream containing only the metadata item (zero templates) must still
/// create the file, and its content must be empty.
#[sqlx::test]
async fn test_batched_insert_empty_batches(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let items = vec![FileStreamItem::Metadata(FileMetadata {
        filename: Some("empty-file".to_string()),
        description: Some("File with no templates".to_string()),
        purpose: None,
        expires_after_anchor: None,
        expires_after_seconds: None,
        size_bytes: None,
        uploaded_by: None,
        api_key_id: None,
    })];
    let stream = stream::iter(items);
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream)
            .await
            .expect("Failed to create empty file"),
    );
    let file = manager.get_file(file_id).await.expect("Failed to get file");
    assert_eq!(file.name, "empty-file");
    let content = manager
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(content.len(), 0, "Should have no templates");
}
/// An `Error` item mid-stream (after 3000 templates already buffered) must
/// fail `create_file_stream` and roll back everything — no `files` row may
/// survive, even though earlier batches had been flushed inside the
/// transaction.
#[sqlx::test]
#[allow(deprecated)]
async fn test_batched_insert_transactional_rollback(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let mut items = vec![FileStreamItem::Metadata(FileMetadata {
        filename: Some("rollback-test".to_string()),
        description: Some("Should rollback on error".to_string()),
        purpose: None,
        expires_after_anchor: None,
        expires_after_seconds: None,
        size_bytes: None,
        uploaded_by: None,
        api_key_id: None,
    })];
    for i in 0..3000 {
        items.push(FileStreamItem::Template(RequestTemplateInput {
            custom_id: Some(format!("rollback-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: format!(r#"{{"n":{}}}"#, i),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        }));
    }
    // Inject a mid-stream failure, followed by more templates that must
    // never be persisted.
    items.push(FileStreamItem::Error("Simulated parse error".to_string()));
    for i in 3000..3100 {
        items.push(FileStreamItem::Template(RequestTemplateInput {
            custom_id: Some(format!("rollback-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: format!(r#"{{"n":{}}}"#, i),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        }));
    }
    let stream = stream::iter(items);
    let result = manager.create_file_stream(stream).await;
    assert!(result.is_err());
    // Verify the rollback at the table level: no file row should remain.
    let files =
        sqlx::query!(r#"SELECT COUNT(*) as count FROM files WHERE name = 'rollback-test'"#)
            .fetch_one(&pool)
            .await
            .unwrap();
    assert_eq!(files.count, Some(0), "File should not exist after rollback");
}
/// The stored `body_byte_size` column must equal the actual stored body
/// length for both a tiny body and a ~5KB body.
#[sqlx::test]
async fn test_batched_insert_body_byte_size_calculation(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let templates = vec![
        RequestTemplateInput {
            custom_id: Some("small".to_string()),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: r#"{"a":1}"#.to_string(), // exactly 7 bytes
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        },
        RequestTemplateInput {
            custom_id: Some("large".to_string()),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: format!(r#"{{"data":"{}"}}"#, "x".repeat(5000)),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        },
    ];
    let file_id = manager
        .create_file("byte-size-test".to_string(), None, templates)
        .await
        .expect("Failed to create file");
    // Compare the recorded byte size against Postgres's own LENGTH() of the
    // stored body.
    let rows = sqlx::query!(
        r#"
        SELECT custom_id, body_byte_size, LENGTH(body) as actual_length
        FROM request_templates
        WHERE file_id = $1
        ORDER BY line_number ASC
        "#,
        *file_id as Uuid,
    )
    .fetch_all(&pool)
    .await
    .unwrap();
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0].custom_id, Some("small".to_string()));
    assert_eq!(rows[0].body_byte_size, 7);
    assert_eq!(rows[0].actual_length, Some(7));
    assert_eq!(rows[1].custom_id, Some("large".to_string()));
    assert_eq!(
        rows[1].body_byte_size,
        rows[1].actual_length.unwrap() as i64
    );
    assert!(rows[1].body_byte_size > 5000);
}
/// Coarse performance smoke test: inserting 1000 templates with the batched
/// strategy must complete within 2 seconds. This guards against accidental
/// per-row round-trips, not precise timing.
#[sqlx::test]
async fn test_batched_insert_performance_comparison(pool: sqlx::PgPool) {
    use std::time::Instant;
    let http_client = Arc::new(MockHttpClient::new());
    let manager_batched = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client.clone(),
    );
    let template_count = 1000;
    let templates: Vec<RequestTemplateInput> = (0..template_count)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("perf-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: format!(r#"{{"prompt":"test {}","data":{}}}"#, i, "x".repeat(50)),
            model: "gpt-4".to_string(),
            api_key: "key".to_string(),
        })
        .collect();
    let start = Instant::now();
    let file_id = manager_batched
        .create_file("perf-test-batched".to_string(), None, templates)
        .await
        .expect("Failed to create file");
    let batched_duration = start.elapsed();
    let content = manager_batched
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(content.len(), template_count);
    println!(
        "Batched insert of {} templates took: {:?}",
        template_count, batched_duration
    );
    // Generous bound: catches per-row insert regressions, not flaky on CI.
    assert!(
        batched_duration.as_secs() < 2,
        "Batched insert should be fast"
    );
}
/// `with_batch_insert_strategy` must panic immediately when configured with a
/// zero batch size (invalid configuration is a programmer error).
#[sqlx::test]
#[should_panic(expected = "batch_size must be greater than 0")]
async fn test_batched_insert_rejects_zero_batch_size(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let _manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool).await.unwrap(), http_client)
            .with_batch_insert_strategy(BatchInsertStrategy::Batched { batch_size: 0 });
}
/// The same 5 templates inserted with batch sizes 1 (one row per batch),
/// 100, and 5000 (single batch) must all yield identical content counts.
#[sqlx::test]
async fn test_batched_insert_valid_batch_sizes(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    // batch_size = 1: every template is its own insert batch.
    let manager1 = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client.clone(),
    )
    .with_batch_insert_strategy(BatchInsertStrategy::Batched { batch_size: 1 });
    let templates: Vec<RequestTemplateInput> = (0..5)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("req-{}", i)),
            endpoint: "/v1/chat/completions".to_string(),
            method: "POST".to_string(),
            path: "/v1/chat/completions".to_string(),
            body: format!(
                r#"{{"model":"gpt-4","messages":[{{"role":"user","content":"test {}"}}]}}"#,
                i
            ),
            model: "gpt-4".to_string(),
            api_key: "test-key".to_string(),
        })
        .collect();
    let file_id1 = manager1
        .create_file("batch-size-1".to_string(), None, templates.clone())
        .await
        .expect("Failed to create file with batch_size=1");
    let content1 = manager1
        .get_file_content(file_id1)
        .await
        .expect("Failed to get content");
    assert_eq!(
        content1.len(),
        5,
        "Should have 5 templates with batch_size=1"
    );
    // batch_size = 100: all templates in one partial batch.
    let manager100 = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client.clone(),
    )
    .with_batch_insert_strategy(BatchInsertStrategy::Batched { batch_size: 100 });
    let file_id100 = manager100
        .create_file("batch-size-100".to_string(), None, templates.clone())
        .await
        .expect("Failed to create file with batch_size=100");
    let content100 = manager100
        .get_file_content(file_id100)
        .await
        .expect("Failed to get content");
    assert_eq!(
        content100.len(),
        5,
        "Should have 5 templates with batch_size=100"
    );
    // batch_size = 5000: the default strategy value.
    let manager5000 =
        PostgresRequestManager::with_client(TestDbPools::new(pool).await.unwrap(), http_client)
            .with_batch_insert_strategy(BatchInsertStrategy::Batched { batch_size: 5000 });
    let file_id5000 = manager5000
        .create_file("batch-size-5000".to_string(), None, templates)
        .await
        .expect("Failed to create file with batch_size=5000");
    let content5000 = manager5000
        .get_file_content(file_id5000)
        .await
        .expect("Failed to get content");
    assert_eq!(
        content5000.len(),
        5,
        "Should have 5 templates with batch_size=5000"
    );
}
/// Creating a batch from a 3-template file must materialize one pending
/// request per template, and the status counters must reflect that initial
/// all-pending state.
#[sqlx::test]
async fn test_create_batch_and_get_status(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "batch-test".to_string(),
            None,
            vec![
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"1"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"2"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"3"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "key".to_string(),
                },
            ],
        )
        .await
        .expect("Failed to create file");
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .expect("Failed to create batch");
    let status = manager
        .get_batch_status(batch.id)
        .await
        .expect("Failed to get batch status");
    assert_eq!(status.batch_id, batch.id);
    assert_eq!(status.file_id, Some(file_id));
    assert_eq!(status.file_name, Some("batch-test".to_string()));
    // Freshly created batch: all requests pending, none resolved.
    assert_eq!(status.total_requests, 3);
    assert_eq!(status.pending_requests, 3);
    assert_eq!(status.completed_requests, 0);
    assert_eq!(status.failed_requests, 0);
    let requests = manager
        .get_batch_requests(batch.id)
        .await
        .expect("Failed to get batch requests");
    assert_eq!(requests.len(), 3);
    for request in requests {
        assert!(request.is_pending());
    }
}
/// Claiming semantics: a daemon claiming 3 of 5 requests gets exactly 3 with
/// its id and retry_attempt 0; a second claim for 10 returns only the 2
/// remaining; the batch then reports all 5 in progress.
#[sqlx::test]
async fn test_claim_requests(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "claim-test".to_string(),
            None,
            (0..5)
                .map(|i| RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: format!(r#"{{"n":{}}}"#, i),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                })
                .collect(),
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let daemon_id = DaemonId::from(Uuid::new_v4());
    // Per-model capacity map; "test" matches the templates' model above.
    let capacity = HashMap::from([("test".to_string(), 10)]);
    let claimed = manager
        .claim_requests(3, daemon_id, &capacity, &HashMap::new())
        .await
        .expect("Failed to claim requests");
    assert_eq!(claimed.len(), 3);
    for request in &claimed {
        assert_eq!(request.state.daemon_id, daemon_id);
        assert_eq!(request.state.retry_attempt, 0);
    }
    // Only 2 requests remain; asking for 10 must return just those 2.
    let claimed2 = manager
        .claim_requests(10, daemon_id, &capacity, &HashMap::new())
        .await
        .expect("Failed to claim requests");
    assert_eq!(claimed2.len(), 2);
    let status = manager.get_batch_status(batch.id).await.unwrap();
    assert_eq!(status.total_requests, 5);
    assert_eq!(status.pending_requests, 0);
    assert_eq!(status.in_progress_requests, 5);
}
/// Batch-level cancel: all 3 pending requests are counted as canceled in the
/// status afterward. Note the rows still surface as `AnyRequest::Pending` from
/// `get_batch_requests` — per these assertions, batch-level cancellation is
/// tracked in the status counters rather than by rewriting each request's
/// state variant (contrast with `test_cancel_individual_requests`).
#[sqlx::test]
async fn test_cancel_batch(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "cancel-test".to_string(),
            None,
            (0..3)
                .map(|i| RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: format!(r#"{{"n":{}}}"#, i),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                })
                .collect(),
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let status_before = manager.get_batch_status(batch.id).await.unwrap();
    assert_eq!(status_before.pending_requests, 3);
    assert_eq!(status_before.canceled_requests, 0);
    manager.cancel_batch(batch.id).await.unwrap();
    let status_after = manager.get_batch_status(batch.id).await.unwrap();
    assert_eq!(status_after.pending_requests, 0);
    assert_eq!(status_after.canceled_requests, 3);
    let requests = manager.get_batch_requests(batch.id).await.unwrap();
    assert_eq!(requests.len(), 3);
    for request in requests {
        assert!(matches!(request, AnyRequest::Pending(_)));
    }
}
/// Deleting a batch removes the batch and its requests but keeps the source
/// file; deleting the same batch a second time is an error.
#[sqlx::test]
async fn test_delete_batch(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "delete-batch-test".to_string(),
            None,
            vec![
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: r#"{"n":1}"#.to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: r#"{"n":2}"#.to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
            ],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let batch_before = manager.get_batch(batch.id).await;
    assert!(batch_before.is_ok());
    let requests_before = manager.get_batch_requests(batch.id).await.unwrap();
    assert_eq!(requests_before.len(), 2);
    manager.delete_batch(batch.id).await.unwrap();
    // The batch and its requests are gone...
    let batch_after = manager.get_batch(batch.id).await;
    assert!(batch_after.is_err());
    let requests_after = manager.get_batch_requests(batch.id).await.unwrap();
    assert_eq!(requests_after.len(), 0);
    // ...but the source file survives the batch deletion.
    let file_after = manager.get_file(file_id).await;
    assert!(file_after.is_ok());
    // Deleting an already-deleted batch must fail.
    let delete_result = manager.delete_batch(batch.id).await;
    assert!(delete_result.is_err());
}
/// Request-level cancel: canceling 3 of 5 pending requests succeeds per-item,
/// updates the batch counters, and — unlike batch-level cancel — transitions
/// those rows to the `AnyRequest::Canceled` variant.
#[sqlx::test]
async fn test_cancel_individual_requests(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "individual-cancel-test".to_string(),
            None,
            (0..5)
                .map(|i| RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: format!(r#"{{"n":{}}}"#, i),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                })
                .collect(),
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let requests = manager.get_batch_requests(batch.id).await.unwrap();
    let request_ids: Vec<_> = requests.iter().map(|r| r.id()).collect();
    // Cancel only the first three; each per-request result must be Ok.
    let results = manager
        .cancel_requests(request_ids[0..3].to_vec())
        .await
        .unwrap();
    for result in results {
        assert!(result.is_ok());
    }
    let status = manager.get_batch_status(batch.id).await.unwrap();
    assert_eq!(status.pending_requests, 2);
    assert_eq!(status.canceled_requests, 3);
    let all_requests = manager.get_batch_requests(batch.id).await.unwrap();
    let canceled_count = all_requests
        .iter()
        .filter(|r| matches!(r, AnyRequest::Canceled(_)))
        .count();
    assert_eq!(canceled_count, 3);
}
#[sqlx::test]
async fn test_list_files(pool: sqlx::PgPool) {
    // list_files with the default filter should include every file we create
    // and preserve each file's name and (optional) description.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );

    // Three empty files: two with descriptions, one without.
    let mut created = Vec::new();
    for (name, description) in [
        ("file1", Some("First")),
        ("file2", Some("Second")),
        ("file3", None),
    ] {
        let id = manager
            .create_file(name.to_string(), description.map(str::to_string), vec![])
            .await
            .unwrap();
        created.push(id);
    }

    let files = manager
        .list_files(crate::batch::FileFilter::default())
        .await
        .unwrap();
    // Other tests may share the database, so only require "at least ours".
    assert!(files.len() >= 3);
    let listed_ids: Vec<_> = files.iter().map(|f| f.id).collect();
    for id in &created {
        assert!(listed_ids.contains(id));
    }

    // Spot-check stored metadata on a described and an undescribed file.
    let file1 = files.iter().find(|f| f.id == created[0]).unwrap();
    assert_eq!(file1.name, "file1");
    assert_eq!(file1.description, Some("First".to_string()));
    let file3 = files.iter().find(|f| f.id == created[2]).unwrap();
    assert_eq!(file3.name, "file3");
    assert_eq!(file3.description, None);
}
#[sqlx::test]
async fn test_list_file_batches(pool: sqlx::PgPool) {
    // Three batches created from one file should all appear in
    // list_file_batches, each counting the single request as pending.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "batch-list-test".to_string(),
            None,
            vec![RequestTemplateInput {
                custom_id: None,
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/test".to_string(),
                body: "{}".to_string(),
                model: "test".to_string(),
                api_key: "key".to_string(),
            }],
        )
        .await
        .unwrap();

    // Create three identical batches over the same input file.
    let mut created_ids = Vec::with_capacity(3);
    for _ in 0..3 {
        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
                api_key_id: None,
                api_key: None,
                total_requests: None,
            })
            .await
            .unwrap();
        created_ids.push(batch.id);
    }

    let batches = manager.list_file_batches(file_id).await.unwrap();
    assert_eq!(batches.len(), 3);
    let listed_ids: Vec<_> = batches.iter().map(|b| b.batch_id).collect();
    for id in &created_ids {
        assert!(listed_ids.contains(id));
    }
    // Each batch sees exactly the one (still pending) request from the file.
    for batch in batches {
        assert_eq!(batch.total_requests, 1);
        assert_eq!(batch.pending_requests, 1);
    }
}
#[sqlx::test]
async fn test_delete_file_cascade(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"delete-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":1}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":2}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let batch_before = manager.get_batch(batch.id).await.unwrap();
assert_eq!(batch_before.file_id, Some(file_id));
assert!(batch_before.cancelling_at.is_none());
assert!(batch_before.cancelled_at.is_none());
assert_eq!(batch_before.pending_requests, 2);
let requests_before = manager.get_batch_requests(batch.id).await.unwrap();
assert_eq!(requests_before.len(), 2);
manager.delete_file(file_id).await.unwrap();
let file_result = manager.get_file(file_id).await;
assert!(file_result.is_err());
let batch_after = manager.get_batch(batch.id).await.unwrap();
assert_eq!(batch_after.file_id, None);
assert!(batch_after.cancelling_at.is_some());
assert!(batch_after.cancelled_at.is_some());
assert_eq!(batch_after.canceled_requests, 2);
let requests_after = manager.get_batch_requests(batch.id).await.unwrap();
assert_eq!(requests_after.len(), 0); }
#[sqlx::test]
// A claimed-but-never-started request whose claim is older than
// claim_timeout_ms should be released and claimable by another daemon.
async fn test_unclaim_stale_claimed_requests(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// 1s claim timeout so a claim backdated by 3s counts as stale; the
// processing timeout is kept large so it cannot interfere with this test.
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 1000, processing_timeout_ms: 60000, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
// Single-request file + batch so there is exactly one claimable request.
let file_id = manager
.create_file(
"stale-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
// daemon1 claims the only request.
let daemon1_id = DaemonId::from(Uuid::new_v4());
let claimed = manager
.claim_requests(1, daemon1_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed.len(), 1);
let request_id = claimed[0].data.id;
// Backdate the claim past the 1s timeout via direct SQL (bypassing the
// manager) so the next claim pass sees it as stale.
sqlx::query!(
"UPDATE requests SET claimed_at = NOW() - INTERVAL '3 seconds' WHERE id = $1",
*request_id as Uuid
)
.execute(&pool)
.await
.unwrap();
// daemon2's claim pass should release the stale claim and hand the same
// request to daemon2.
let daemon2_id = DaemonId::from(Uuid::new_v4());
let reclaimed = manager
.claim_requests(1, daemon2_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(reclaimed.len(), 1);
assert_eq!(reclaimed[0].data.id, request_id);
assert_eq!(reclaimed[0].state.daemon_id, daemon2_id);
// The reclaimed request still counts toward the batch's in-progress total.
let status = manager.get_batch_status(batch.id).await.unwrap();
assert_eq!(status.in_progress_requests, 1);
}
#[sqlx::test]
// A request stuck in 'processing' longer than processing_timeout_ms should
// be released and claimable by another daemon.
async fn test_unclaim_stale_processing_requests(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// 1s processing timeout so a start time backdated by 3s counts as stale;
// the claim timeout is kept large so it cannot interfere with this test.
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 60000, processing_timeout_ms: 1000, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
// Single-request file + batch so there is exactly one claimable request.
let file_id = manager
.create_file(
"stale-processing-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
// daemon1 claims the only request.
let daemon1_id = DaemonId::from(Uuid::new_v4());
let claimed = manager
.claim_requests(1, daemon1_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed.len(), 1);
let request_id = claimed[0].data.id;
// Move the request to 'processing' with a start time older than the 1s
// processing timeout, via direct SQL (bypassing the manager).
sqlx::query!(
r#"
UPDATE requests
SET
state = 'processing',
started_at = NOW() - INTERVAL '3 seconds'
WHERE id = $1
"#,
*request_id as Uuid
)
.execute(&pool)
.await
.unwrap();
// The stale processing request still counts as in-progress.
let status_before = manager.get_batch_status(batch.id).await.unwrap();
assert_eq!(status_before.in_progress_requests, 1);
// daemon2's claim pass should release the timed-out request and hand it
// to daemon2.
let daemon2_id = DaemonId::from(Uuid::new_v4());
let reclaimed = manager
.claim_requests(1, daemon2_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(reclaimed.len(), 1);
assert_eq!(reclaimed[0].data.id, request_id);
assert_eq!(reclaimed[0].state.daemon_id, daemon2_id);
}
#[sqlx::test]
// A request held by a daemon recorded as Dead should be reclaimable even
// though neither the claim nor the processing timeout has elapsed.
async fn test_unclaim_requests_from_dead_daemon(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// Both request timeouts are 10 minutes, so only the daemon-liveness check
// (1s stale threshold) can release the request in this test.
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 600000,
processing_timeout_ms: 600000,
stale_daemon_threshold_ms: 1000, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
// Single-request file + batch so there is exactly one claimable request.
let file_id = manager
.create_file(
"dead-daemon-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Persist daemon1 as already Dead before it claims anything.
let daemon1_id = DaemonId::from(Uuid::new_v4());
let daemon1 = DaemonRecord {
data: DaemonData {
id: daemon1_id,
hostname: "test-host".to_string(),
pid: 1234,
version: "test".to_string(),
config_snapshot: serde_json::json!({}),
},
state: Dead {
started_at: Utc::now() - chrono::Duration::minutes(10),
stopped_at: Utc::now(),
final_stats: DaemonStats::default(),
},
};
manager.persist_daemon(&daemon1).await.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
// The dead daemon's id claims the only request.
let claimed = manager
.claim_requests(1, daemon1_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed.len(), 1);
let request_id = claimed[0].data.id;
// Mark the request as actively processing with a fresh started_at, so
// only the dead-daemon check (not a timeout) can make it reclaimable.
sqlx::query!(
"UPDATE requests SET state = 'processing', started_at = NOW() WHERE id = $1",
*request_id as Uuid
)
.execute(&pool)
.await
.unwrap();
let status = manager.get_batch_status(batch.id).await.unwrap();
assert_eq!(status.in_progress_requests, 1);
// daemon2 should take over the dead daemon's request.
let daemon2_id = DaemonId::from(Uuid::new_v4());
let reclaimed = manager
.claim_requests(1, daemon2_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(reclaimed.len(), 1);
assert_eq!(reclaimed[0].data.id, request_id);
assert_eq!(reclaimed[0].state.daemon_id, daemon2_id);
}
#[sqlx::test]
// A request held by a Running daemon whose last heartbeat is older than
// stale_daemon_threshold_ms should be reclaimable even though neither
// request timeout has elapsed.
async fn test_unclaim_requests_from_stale_heartbeat_daemon(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// Both request timeouts are 10 minutes, so only the heartbeat staleness
// check (1s threshold) can release the request in this test.
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 600000,
processing_timeout_ms: 600000,
stale_daemon_threshold_ms: 1000, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
// Single-request file + batch so there is exactly one claimable request.
let file_id = manager
.create_file(
"stale-heartbeat-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Persist daemon1 as Running but with a heartbeat 5s old — well past the
// 1s stale threshold configured above.
let daemon1_id = DaemonId::from(Uuid::new_v4());
let daemon1 = DaemonRecord {
data: DaemonData {
id: daemon1_id,
hostname: "test-host".to_string(),
pid: 1234,
version: "test".to_string(),
config_snapshot: serde_json::json!({}),
},
state: Running {
started_at: Utc::now() - chrono::Duration::minutes(10),
last_heartbeat: Utc::now() - chrono::Duration::seconds(5), stats: DaemonStats::default(),
},
};
manager.persist_daemon(&daemon1).await.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
// The stale daemon's id claims the only request.
let claimed = manager
.claim_requests(1, daemon1_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed.len(), 1);
let request_id = claimed[0].data.id;
// Mark the request as actively processing with a fresh started_at, so
// only the stale-heartbeat check can make it reclaimable.
sqlx::query!(
"UPDATE requests SET state = 'processing', started_at = NOW() WHERE id = $1",
*request_id as Uuid
)
.execute(&pool)
.await
.unwrap();
let status = manager.get_batch_status(batch.id).await.unwrap();
assert_eq!(status.in_progress_requests, 1);
// daemon2 should take over the stale daemon's request.
let daemon2_id = DaemonId::from(Uuid::new_v4());
let reclaimed = manager
.claim_requests(1, daemon2_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(reclaimed.len(), 1);
assert_eq!(reclaimed[0].data.id, request_id);
assert_eq!(reclaimed[0].state.daemon_id, daemon2_id);
}
#[sqlx::test]
// Requests held by a daemon with a fresh heartbeat must NOT be reclaimed:
// a second daemon should receive the other pending request instead.
async fn test_dont_unclaim_requests_from_healthy_daemon(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// Generous timeouts and a 60s stale threshold; the daemon below
// heartbeats "now", so nothing here should trigger a release.
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 600000,
processing_timeout_ms: 600000,
stale_daemon_threshold_ms: 60000, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
// Two requests: daemon1 takes one, daemon2 should get the other.
let file_id = manager
.create_file(
"healthy-daemon-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":1}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":2}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Persist daemon1 as Running with a current heartbeat (healthy).
let daemon1_id = DaemonId::from(Uuid::new_v4());
let daemon1 = DaemonRecord {
data: DaemonData {
id: daemon1_id,
hostname: "test-host".to_string(),
pid: 1234,
version: "test".to_string(),
config_snapshot: serde_json::json!({}),
},
state: Running {
started_at: Utc::now() - chrono::Duration::minutes(10),
last_heartbeat: Utc::now(), stats: DaemonStats::default(),
},
};
manager.persist_daemon(&daemon1).await.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
// daemon1 claims one of the two requests.
let claimed = manager
.claim_requests(1, daemon1_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed.len(), 1);
let request_id = claimed[0].data.id;
// Move daemon1's request into 'processing' via direct SQL.
sqlx::query!(
"UPDATE requests SET state = 'processing', started_at = NOW() WHERE id = $1",
*request_id as Uuid
)
.execute(&pool)
.await
.unwrap();
// daemon2 claims — it must receive the *other* request, not steal
// daemon1's.
let daemon2_id = DaemonId::from(Uuid::new_v4());
let claimed2 = manager
.claim_requests(1, daemon2_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed2.len(), 1);
assert_ne!(claimed2[0].data.id, request_id);
// daemon1's request was not disturbed.
let results = manager.get_requests(vec![request_id]).await.unwrap();
assert!(
matches!(&results[0], Ok(crate::AnyRequest::Processing(_))),
"Request should still be in processing state"
);
}
#[sqlx::test]
async fn test_dont_unclaim_recent_requests(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 60000, processing_timeout_ms: 600000, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
let file_id = manager
.create_file(
"recent-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":1}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":2}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
let daemon1_id = DaemonId::from(Uuid::new_v4());
let claimed1 = manager
.claim_requests(1, daemon1_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed1.len(), 1);
let daemon2_id = DaemonId::from(Uuid::new_v4());
let claimed2 = manager
.claim_requests(1, daemon2_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed2.len(), 1);
assert_ne!(claimed1[0].data.id, claimed2[0].data.id);
let results = manager
.get_requests(vec![claimed1[0].data.id])
.await
.unwrap();
if let Ok(crate::AnyRequest::Claimed(req)) = &results[0] {
assert_eq!(req.state.daemon_id, daemon1_id);
} else {
panic!("Request should still be claimed by daemon1");
}
}
#[sqlx::test]
// When a stale claim is released and re-claimed, the request's accumulated
// retry_attempt counter must be preserved across the handoff.
async fn test_preserve_retry_attempt_on_unclaim(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// 1s claim timeout so a backdated claim counts as stale.
let config = crate::daemon::DaemonConfig {
claim_timeout_ms: 1000,
processing_timeout_ms: 60000,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
)
.with_config(config),
);
// Single-request file + batch so there is exactly one claimable request.
let file_id = manager
.create_file(
"retry-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Simulate a prior claim by some unrelated daemon: mark the pending
// request as claimed with retry_attempt = 2 and a claim backdated past
// the 1s timeout, so the next claim pass treats it as stale.
sqlx::query!(
r#"
UPDATE requests
SET
retry_attempt = 2,
state = 'claimed',
daemon_id = $1,
claimed_at = NOW() - INTERVAL '3 seconds'
WHERE id IN (SELECT id FROM requests WHERE state = 'pending' LIMIT 1)
RETURNING id
"#,
Uuid::new_v4()
)
.fetch_one(&pool)
.await
.unwrap();
let capacity = HashMap::from([("test".to_string(), 10)]);
// Re-claiming the stale request must carry retry_attempt = 2 through.
let daemon_id = DaemonId::from(Uuid::new_v4());
let claimed = manager
.claim_requests(1, daemon_id, &capacity, &HashMap::new())
.await
.unwrap();
assert_eq!(claimed.len(), 1);
assert_eq!(claimed[0].state.retry_attempt, 2);
}
#[sqlx::test]
// End-to-end check of batch result streaming: completed requests appear in
// the auto-created output file, failed requests in the error file, and the
// original input file still streams its request templates.
async fn test_batch_output_and_error_streaming(pool: sqlx::PgPool) {
use futures::StreamExt;
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
// Three templates with distinct custom_ids so streamed items can be
// matched back to their source requests.
let file_id = manager
.create_file(
"streaming-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: Some("req-1".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/chat/completions".to_string(),
body: r#"{"prompt":"first"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: Some("req-2".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/chat/completions".to_string(),
body: r#"{"prompt":"second"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: Some("req-3".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/chat/completions".to_string(),
body: r#"{"prompt":"third"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: Some("test-user".to_string()),
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
// Batch creation auto-provisions both result files.
assert!(batch.output_file_id.is_some());
assert!(batch.error_file_id.is_some());
let output_file_id = batch.output_file_id.unwrap();
let error_file_id = batch.error_file_id.unwrap();
let output_file = manager
.get_file(output_file_id)
.await
.expect("Failed to get output file");
let error_file = manager
.get_file(error_file_id)
.await
.expect("Failed to get error file");
// Result files follow the batch-{id}-output/error.jsonl naming scheme.
assert_eq!(
output_file.name,
format!("batch-{}-output.jsonl", batch.id.0)
);
assert_eq!(error_file.name, format!("batch-{}-error.jsonl", batch.id.0));
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get requests");
assert_eq!(requests.len(), 3);
// Complete the first two requests and fail the third via direct SQL so
// no daemon has to actually execute anything.
sqlx::query!(
r#"
UPDATE requests
SET state = 'completed',
response_status = 200,
response_body = $2,
completed_at = NOW()
WHERE id = $1
"#,
*requests[0].id() as Uuid,
r#"{"id":"chatcmpl-123","choices":[{"message":{"content":"Response 1"}}]}"#,
)
.execute(&pool)
.await
.expect("Failed to mark request as completed");
sqlx::query!(
r#"
UPDATE requests
SET state = 'completed',
response_status = 200,
response_body = $2,
completed_at = NOW()
WHERE id = $1
"#,
*requests[1].id() as Uuid,
r#"{"id":"chatcmpl-456","choices":[{"message":{"content":"Response 2"}}]}"#,
)
.execute(&pool)
.await
.expect("Failed to mark request as completed");
sqlx::query!(
r#"
UPDATE requests
SET state = 'failed',
error = $2,
failed_at = NOW()
WHERE id = $1
"#,
*requests[2].id() as Uuid,
"Rate limit exceeded",
)
.execute(&pool)
.await
.expect("Failed to mark request as failed");
// Output stream: exactly the two completed requests, each wrapped as an
// Output item with the stored status/body and no error.
let output_stream = manager.get_file_content_stream(output_file_id, 0, None);
let output_items: Vec<_> = output_stream.collect().await;
assert_eq!(output_items.len(), 2, "Should have 2 output items");
let mut found_custom_ids = Vec::new();
for item_result in output_items.iter() {
let item = item_result.as_ref().expect("Output item should be Ok");
match item {
FileContentItem::Output(output) => {
found_custom_ids.push(output.custom_id.clone());
assert_eq!(output.response.status_code, 200);
assert!(output.response.body.is_object());
assert!(output.error.is_none());
// Output item ids carry the "batch_req_" prefix.
assert!(output.id.starts_with("batch_req_"));
}
_ => panic!("Expected FileContentItem::Output, got different type"),
}
}
// Stream order is not asserted; sort before comparing the custom_ids.
found_custom_ids.sort();
assert_eq!(
found_custom_ids,
vec![Some("req-1".to_string()), Some("req-2".to_string())]
);
// Error stream: exactly the one failed request, carrying its error
// message and no response payload.
let error_stream = manager.get_file_content_stream(error_file_id, 0, None);
let error_items: Vec<_> = error_stream.collect().await;
assert_eq!(error_items.len(), 1, "Should have 1 error item");
let error_result = &error_items[0];
let error_item = error_result.as_ref().expect("Error item should be Ok");
match error_item {
FileContentItem::Error(error) => {
assert_eq!(error.custom_id, Some("req-3".to_string()));
assert_eq!(error.error.message, "Rate limit exceeded");
assert!(error.response.is_none());
assert!(error.id.starts_with("batch_req_"));
}
_ => panic!("Expected FileContentItem::Error, got different type"),
}
// The input file still streams its three request templates, untouched by
// the result bookkeeping above.
let input_stream = manager.get_file_content_stream(file_id, 0, None);
let input_items: Vec<_> = input_stream.collect().await;
assert_eq!(input_items.len(), 3, "Input file should have 3 templates");
for item_result in input_items {
let item = item_result.expect("Input item should be Ok");
match item {
FileContentItem::Template(_) => {
}
_ => panic!("Expected FileContentItem::Template for input file"),
}
}
}
#[sqlx::test]
async fn test_daemon_persist_and_get(pool: sqlx::PgPool) {
    // Walk one daemon through the Initializing -> Running -> Dead lifecycle,
    // persisting after each transition and checking what get_daemon returns.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let daemon_id = DaemonId(Uuid::new_v4());
    let daemon_data = DaemonData {
        id: daemon_id,
        hostname: "test-host".to_string(),
        pid: 12345,
        version: "1.0.0".to_string(),
        config_snapshot: serde_json::json!({"test": "config"}),
    };

    // 1) Initializing.
    manager
        .persist_daemon(&DaemonRecord {
            data: daemon_data.clone(),
            state: Initializing {
                started_at: Utc::now(),
            },
        })
        .await
        .unwrap();
    let AnyDaemonRecord::Initializing(d) = manager.get_daemon(daemon_id).await.unwrap() else {
        panic!("Expected Initializing state");
    };
    assert_eq!(d.data.id, daemon_id);
    assert_eq!(d.data.hostname, "test-host");

    // 2) Running, with non-trivial stats.
    manager
        .persist_daemon(&DaemonRecord {
            data: daemon_data.clone(),
            state: Running {
                started_at: Utc::now(),
                last_heartbeat: Utc::now(),
                stats: DaemonStats {
                    requests_processed: 10,
                    requests_failed: 2,
                    requests_in_flight: 3,
                },
            },
        })
        .await
        .unwrap();
    let AnyDaemonRecord::Running(d) = manager.get_daemon(daemon_id).await.unwrap() else {
        panic!("Expected Running state");
    };
    assert_eq!(d.data.id, daemon_id);
    assert_eq!(d.state.stats.requests_processed, 10);
    assert_eq!(d.state.stats.requests_failed, 2);
    assert_eq!(d.state.stats.requests_in_flight, 3);

    // 3) Dead, with final stats.
    manager
        .persist_daemon(&DaemonRecord {
            data: daemon_data,
            state: Dead {
                started_at: Utc::now() - chrono::Duration::hours(1),
                stopped_at: Utc::now(),
                final_stats: DaemonStats {
                    requests_processed: 100,
                    requests_failed: 5,
                    requests_in_flight: 0,
                },
            },
        })
        .await
        .unwrap();
    let AnyDaemonRecord::Dead(d) = manager.get_daemon(daemon_id).await.unwrap() else {
        panic!("Expected Dead state");
    };
    assert_eq!(d.data.id, daemon_id);
    assert_eq!(d.state.final_stats.requests_processed, 100);
    assert_eq!(d.state.final_stats.requests_failed, 5);
}
#[sqlx::test]
async fn test_daemon_list_all(pool: sqlx::PgPool) {
    // Two running daemons and one dead daemon: list_daemons(None) returns all
    // three, while the status filter narrows to each subset.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );

    // Small factory for the shared data section of each daemon record.
    let make_data = |host: &str, pid| DaemonData {
        id: DaemonId(Uuid::new_v4()),
        hostname: host.to_string(),
        pid,
        version: "1.0.0".to_string(),
        config_snapshot: serde_json::json!({}),
    };

    let running1 = DaemonRecord {
        data: make_data("host1", 1001),
        state: Running {
            started_at: Utc::now(),
            last_heartbeat: Utc::now(),
            stats: DaemonStats::default(),
        },
    };
    let running2 = DaemonRecord {
        data: make_data("host2", 1002),
        state: Running {
            started_at: Utc::now(),
            last_heartbeat: Utc::now(),
            stats: DaemonStats::default(),
        },
    };
    let dead_daemon = DaemonRecord {
        data: make_data("host3", 1003),
        state: Dead {
            started_at: Utc::now() - chrono::Duration::hours(1),
            stopped_at: Utc::now(),
            final_stats: DaemonStats::default(),
        },
    };
    manager.persist_daemon(&running1).await.unwrap();
    manager.persist_daemon(&running2).await.unwrap();
    manager.persist_daemon(&dead_daemon).await.unwrap();

    // No filter: everything.
    let all = manager.list_daemons(None).await.unwrap();
    assert_eq!(all.len(), 3);
    // Filtered by status.
    let running = manager
        .list_daemons(Some(DaemonStatus::Running))
        .await
        .unwrap();
    assert_eq!(running.len(), 2);
    let dead = manager
        .list_daemons(Some(DaemonStatus::Dead))
        .await
        .unwrap();
    assert_eq!(dead.len(), 1);
}
#[sqlx::test]
async fn test_daemon_heartbeat_updates(pool: sqlx::PgPool) {
    // Persist the same Running daemon several times with fresh heartbeats and
    // growing stats; the last persisted values must win.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let daemon_id = DaemonId(Uuid::new_v4());
    let base = DaemonRecord {
        data: DaemonData {
            id: daemon_id,
            hostname: "test-host".to_string(),
            pid: 12345,
            version: "1.0.0".to_string(),
            config_snapshot: serde_json::json!({}),
        },
        state: Running {
            started_at: Utc::now(),
            last_heartbeat: Utc::now(),
            stats: DaemonStats {
                requests_processed: 0,
                requests_failed: 0,
                requests_in_flight: 0,
            },
        },
    };
    manager.persist_daemon(&base).await.unwrap();

    // Three heartbeats, each bumping the stats.
    for beat in 1..=3 {
        let heartbeat = DaemonRecord {
            data: base.data.clone(),
            state: Running {
                started_at: base.state.started_at,
                last_heartbeat: Utc::now(),
                stats: DaemonStats {
                    requests_processed: beat * 10,
                    requests_failed: beat,
                    requests_in_flight: beat as usize,
                },
            },
        };
        manager.persist_daemon(&heartbeat).await.unwrap();
    }

    // Only the final heartbeat's stats should remain.
    let AnyDaemonRecord::Running(d) = manager.get_daemon(daemon_id).await.unwrap() else {
        panic!("Expected Running state");
    };
    assert_eq!(d.state.stats.requests_processed, 30);
    assert_eq!(d.state.stats.requests_failed, 3);
    assert_eq!(d.state.stats.requests_in_flight, 3);
}
#[sqlx::test]
async fn test_create_file_stream_with_metadata_and_templates(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    // Metadata first, then two templates: the streamed file should pick up
    // the metadata and store the templates in stream order.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );

    // Small factory for stream template items.
    let template = |custom_id: &str, body: &str, model: &str, api_key: &str| {
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: Some(custom_id.to_string()),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/completions".to_string(),
            body: body.to_string(),
            model: model.to_string(),
            api_key: api_key.to_string(),
        })
    };
    let stream = stream::iter(vec![
        FileStreamItem::Metadata(FileMetadata {
            filename: Some("streamed-file".to_string()),
            description: Some("A file created via streaming".to_string()),
            purpose: None,
            expires_after_anchor: None,
            expires_after_seconds: None,
            size_bytes: None,
            uploaded_by: Some("test-user".to_string()),
            api_key_id: None,
        }),
        template("stream-1", r#"{"prompt":"first"}"#, "gpt-4", "key1"),
        template("stream-2", r#"{"prompt":"second"}"#, "gpt-3.5", "key2"),
    ]);
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream)
            .await
            .expect("Failed to create file from stream"),
    );

    // Metadata was applied to the file record.
    let file = manager.get_file(file_id).await.expect("Failed to get file");
    assert_eq!(file.name, "streamed-file");
    assert_eq!(
        file.description,
        Some("A file created via streaming".to_string())
    );

    // Both templates were stored, in stream order.
    let content = manager
        .get_file_content(file_id)
        .await
        .expect("Failed to get content");
    assert_eq!(content.len(), 2);
    let expected = [("stream-1", "gpt-4"), ("stream-2", "gpt-3.5")];
    for (item, (custom_id, model)) in content.iter().zip(expected) {
        match item {
            FileContentItem::Template(t) => {
                assert_eq!(t.custom_id, Some(custom_id.to_string()));
                assert_eq!(t.model, model);
            }
            _ => panic!("Expected template"),
        }
    }
}
#[sqlx::test]
async fn test_create_file_stream_templates_before_metadata(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    // Metadata arriving mid-stream (after the first template) should still be
    // applied to the file, and both templates should be stored.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );

    // Small factory for anonymous stream template items.
    let template = |body: &str| {
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: None,
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: body.to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        })
    };
    let stream = stream::iter(vec![
        template(r#"{"n":1}"#),
        FileStreamItem::Metadata(FileMetadata {
            filename: Some("late-metadata".to_string()),
            description: Some("Metadata came late".to_string()),
            purpose: None,
            expires_after_anchor: None,
            expires_after_seconds: None,
            size_bytes: None,
            uploaded_by: Some("test-user".to_string()),
            api_key_id: None,
        }),
        template(r#"{"n":2}"#),
    ]);
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream)
            .await
            .expect("Failed to create file from stream"),
    );

    // The late metadata was still applied to the file record.
    let file = manager.get_file(file_id).await.unwrap();
    assert_eq!(file.name, "late-metadata");
    assert_eq!(file.description, Some("Metadata came late".to_string()));
    // Both templates made it into the file despite the late metadata.
    let content = manager.get_file_content(file_id).await.unwrap();
    assert_eq!(content.len(), 2);
}
/// A `FileStreamItem::Abort` mid-stream must yield `FileStreamResult::Aborted`
/// and leave no `files` row behind — the partial insert is rolled back.
#[sqlx::test]
async fn test_create_file_stream_abort_handling(pool: sqlx::PgPool) {
    use crate::batch::FileStreamItem;
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // One valid template, then an abort; the trailing template should never land.
    let items = vec![
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: None,
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: r#"{"n":1}"#.to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        }),
        FileStreamItem::Abort,
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: None,
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: r#"{"n":2}"#.to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        }),
    ];
    let stream = stream::iter(items);
    let result = manager.create_file_stream(stream).await;
    match result {
        Ok(FileStreamResult::Aborted) => {}
        _ => panic!("Expected Aborted"),
    }
    // Verify rollback at the DB level: no file rows at all.
    let files = sqlx::query(r#"SELECT COUNT(*) as count FROM files"#)
        .fetch_one(&pool)
        .await
        .unwrap();
    let count: i64 = files.get("count");
    assert_eq!(count, 0, "Aborted stream should roll back inserts");
}
/// The deprecated `FileStreamItem::Error` variant must surface as a
/// `FusilladeError::ValidationError` carrying the original message verbatim.
#[sqlx::test]
#[allow(deprecated)]
async fn test_create_file_stream_deprecated_error_handling(pool: sqlx::PgPool) {
    use crate::batch::FileStreamItem;
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // One valid template followed by an Error item terminating the stream.
    let items = vec![
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: None,
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: r#"{"n":1}"#.to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        }),
        FileStreamItem::Error("Invalid JSON on line 2".to_string()),
    ];
    let result = manager.create_file_stream(stream::iter(items)).await;
    match result {
        Err(FusilladeError::ValidationError(msg)) => {
            assert_eq!(msg, "Invalid JSON on line 2");
        }
        _ => panic!("Expected ValidationError"),
    }
}
/// `get_batch` must round-trip every field supplied at creation (endpoint,
/// window, metadata, creator), auto-assign output/error files, and report
/// initial progress counters (all requests pending).
#[sqlx::test]
async fn test_get_batch(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "batch-retrieval-test".to_string(),
            Some("Test file for batch retrieval".to_string()),
            vec![RequestTemplateInput {
                custom_id: Some("req-1".to_string()),
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/v1/completions".to_string(),
                body: r#"{"prompt":"test"}"#.to_string(),
                model: "gpt-4".to_string(),
                api_key: "key".to_string(),
            }],
        )
        .await
        .unwrap();
    let batch_input = crate::batch::BatchInput {
        file_id,
        endpoint: "/v1/chat/completions".to_string(),
        completion_window: "24h".to_string(),
        metadata: Some(serde_json::json!({"project": "test"})),
        created_by: Some("test-user".to_string()),
        api_key_id: None,
        api_key: None,
        total_requests: None,
    };
    let created_batch = manager.create_batch(batch_input).await.unwrap();
    let retrieved_batch = manager
        .get_batch(created_batch.id)
        .await
        .expect("Failed to get batch");
    // All creation-time fields must survive the round trip.
    assert_eq!(retrieved_batch.id, created_batch.id);
    assert_eq!(retrieved_batch.file_id, Some(file_id));
    assert_eq!(retrieved_batch.endpoint, "/v1/chat/completions");
    assert_eq!(retrieved_batch.completion_window, "24h");
    assert_eq!(
        retrieved_batch.metadata,
        Some(serde_json::json!({"project": "test"}))
    );
    assert_eq!(retrieved_batch.created_by, "test-user");
    // Output/error virtual files are allocated at batch creation.
    assert!(retrieved_batch.output_file_id.is_some());
    assert!(retrieved_batch.error_file_id.is_some());
    // Fresh batch: the single request is still pending.
    assert_eq!(retrieved_batch.total_requests, 1);
    assert_eq!(retrieved_batch.pending_requests, 1);
    assert_eq!(retrieved_batch.completed_requests, 0);
}
/// `get_batch` with an unknown id must fail with a `FusilladeError::Other`
/// whose message contains "Batch not found".
#[sqlx::test]
async fn test_get_batch_not_found(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // Random UUID — no batch was ever created with it.
    let fake_batch_id = BatchId(Uuid::new_v4());
    let result = manager.get_batch(fake_batch_id).await;
    assert!(result.is_err());
    match result {
        Err(FusilladeError::Other(e)) => {
            assert!(e.to_string().contains("Batch not found"));
        }
        _ => panic!("Expected Other error with 'Batch not found' message"),
    }
}
/// Progress counters from `get_batch` must reflect mixed request states:
/// 5 total = 3 pending + 1 claimed (in progress) + 1 completed.
#[sqlx::test]
async fn test_get_batch_with_progress(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // Five identical-model requests in one file.
    let file_id = manager
        .create_file(
            "progress-test".to_string(),
            None,
            (0..5)
                .map(|i| RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: format!(r#"{{"n":{}}}"#, i),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                })
                .collect(),
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    // Claim two requests, then mark one of them completed directly in the DB
    // to simulate a daemon finishing it.
    let capacity = HashMap::from([("test".to_string(), 10)]);
    let daemon_id = DaemonId::from(Uuid::new_v4());
    let claimed = manager
        .claim_requests(2, daemon_id, &capacity, &HashMap::new())
        .await
        .unwrap();
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'completed',
            response_status = 200,
            response_body = '{"result":"ok"}',
            completed_at = NOW()
        WHERE id = $1
        "#,
        *claimed[0].data.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let retrieved = manager.get_batch(batch.id).await.unwrap();
    assert_eq!(retrieved.total_requests, 5);
    assert_eq!(retrieved.pending_requests, 3);
    // One request is still claimed-but-unfinished; the other was completed above.
    assert_eq!(retrieved.in_progress_requests, 1);
    assert_eq!(retrieved.completed_requests, 1);
    assert_eq!(retrieved.failed_requests, 0);
    assert_eq!(retrieved.canceled_requests, 0);
}
/// When every request in a batch is complete, the first `get_batch` call
/// must lazily finalize the batch (set `finalizing_at`/`completed_at`), and
/// a second call must not move those timestamps.
#[sqlx::test]
async fn test_get_batch_lazy_finalization(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "lazy-finalization-test".to_string(),
            None,
            (0..3)
                .map(|i| RequestTemplateInput {
                    custom_id: Some(format!("req-{}", i)),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: format!(r#"{{"n":{}}}"#, i),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                })
                .collect(),
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    // Complete every request of the batch directly in the DB.
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'completed',
            response_status = 200,
            response_body = '{"result":"ok"}',
            completed_at = NOW()
        WHERE batch_id = $1
        "#,
        *batch.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let retrieved = manager.get_batch(batch.id).await.unwrap();
    assert_eq!(retrieved.total_requests, 3);
    assert_eq!(retrieved.completed_requests, 3);
    assert_eq!(retrieved.pending_requests, 0);
    assert_eq!(retrieved.failed_requests, 0);
    assert!(
        retrieved.finalizing_at.is_some(),
        "finalizing_at should be set by lazy finalization"
    );
    assert!(
        retrieved.completed_at.is_some(),
        "completed_at should be set by lazy finalization"
    );
    assert!(
        retrieved.failed_at.is_none(),
        "failed_at should be None for completed batch"
    );
    // A second read must observe the same finalization timestamps.
    let retrieved_again = manager.get_batch(batch.id).await.unwrap();
    // Truncate to microsecond precision: Postgres timestamps store microseconds,
    // so nanosecond-level differences from the in-memory value are expected.
    // NOTE(review): with_nanosecond/nanosecond need `chrono::Timelike` in scope —
    // presumably imported by this test module; confirm.
    let truncate_nanos = |ts: Option<chrono::DateTime<chrono::Utc>>| {
        ts.map(|t| t.with_nanosecond(t.nanosecond() / 1000 * 1000).unwrap())
    };
    assert_eq!(
        truncate_nanos(retrieved.finalizing_at),
        truncate_nanos(retrieved_again.finalizing_at)
    );
    assert_eq!(
        truncate_nanos(retrieved.completed_at),
        truncate_nanos(retrieved_again.completed_at)
    );
}
/// `get_requests` must return each request typed by its current state:
/// after claiming 2 of 5 and forcing one to `completed` and one pending
/// request to `failed`, the result is 1 completed + 1 failed + 1 claimed
/// + 2 pending.
#[sqlx::test]
async fn test_get_requests_various_states(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "get-requests-test".to_string(),
            None,
            (0..5)
                .map(|i| RequestTemplateInput {
                    custom_id: Some(format!("req-{}", i)),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: format!(r#"{{"n":{}}}"#, i),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                })
                .collect(),
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let all_requests = manager.get_batch_requests(batch.id).await.unwrap();
    let request_ids: Vec<_> = all_requests.iter().map(|r| r.id()).collect();
    // Claim 2 of the 5 requests for a synthetic daemon.
    let capacity = HashMap::from([("test".to_string(), 10)]);
    let daemon_id = DaemonId::from(Uuid::new_v4());
    let claimed = manager
        .claim_requests(2, daemon_id, &capacity, &HashMap::new())
        .await
        .unwrap();
    let claimed_ids: Vec<_> = claimed.iter().map(|r| r.data.id).collect();
    // Force one claimed request to 'completed' directly in the DB.
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'completed',
            started_at = NOW() - INTERVAL '1 minute',
            response_status = 200,
            response_body = '{"done":true}',
            completed_at = NOW()
        WHERE id = $1
        "#,
        *claimed_ids[0] as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    // Pick a request that was NOT claimed and force it to 'failed'.
    let pending_id = request_ids
        .iter()
        .find(|id| !claimed_ids.contains(id))
        .unwrap();
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'failed',
            error = 'Rate limit exceeded',
            failed_at = NOW()
        WHERE id = $1
        "#,
        **pending_id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let results = manager.get_requests(request_ids.clone()).await.unwrap();
    assert_eq!(results.len(), 5);
    // Map each typed result to a state label for counting.
    let states: Vec<_> = results
        .iter()
        .map(|r| match r {
            Ok(AnyRequest::Pending(_)) => "pending",
            Ok(AnyRequest::Claimed(_)) => "claimed",
            Ok(AnyRequest::Processing(_)) => "processing",
            Ok(AnyRequest::Completed(_)) => "completed",
            Ok(AnyRequest::Failed(_)) => "failed",
            Ok(AnyRequest::Canceled(_)) => "canceled",
            Err(_) => "error",
        })
        .collect();
    assert_eq!(states.iter().filter(|&&s| s == "completed").count(), 1);
    assert_eq!(states.iter().filter(|&&s| s == "failed").count(), 1);
    assert_eq!(states.iter().filter(|&&s| s == "claimed").count(), 1);
    assert_eq!(states.iter().filter(|&&s| s == "pending").count(), 2);
}
/// `get_requests` must preserve the user-supplied `custom_id` of every
/// request. Asserts the exact set of ids rather than per-item membership,
/// so two requests carrying the same id can no longer pass unnoticed.
#[sqlx::test]
async fn test_get_requests_preserves_custom_ids(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "custom-id-test".to_string(),
            None,
            vec![
                RequestTemplateInput {
                    custom_id: Some("my-custom-id-1".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: r#"{"test":1}"#.to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: Some("my-custom-id-2".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: r#"{"test":2}"#.to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
            ],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let all_requests = manager.get_batch_requests(batch.id).await.unwrap();
    let request_ids: Vec<_> = all_requests.iter().map(|r| r.id()).collect();
    let results = manager.get_requests(request_ids).await.unwrap();
    assert_eq!(results.len(), 2);
    // Collect every custom_id; all requests are still Pending at this point.
    let mut custom_ids: Vec<String> = results
        .into_iter()
        .map(|result| {
            let request = result.expect("Request should be Ok");
            match request {
                AnyRequest::Pending(r) => r.data.custom_id.expect("custom_id should be set"),
                _ => panic!("Expected Pending"),
            }
        })
        .collect();
    // Order is unspecified; compare as a sorted set so BOTH ids must appear
    // exactly once (a duplicate would have slipped past a per-item check).
    custom_ids.sort();
    assert_eq!(
        custom_ids,
        vec!["my-custom-id-1".to_string(), "my-custom-id-2".to_string()]
    );
}
/// `get_requests` with a mix of real and fabricated ids must return a result
/// per requested id, in order: `Ok` for the real one, `Err` for the unknowns.
#[sqlx::test]
async fn test_get_requests_with_nonexistent_ids(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "mixed-ids-test".to_string(),
            None,
            vec![RequestTemplateInput {
                custom_id: None,
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/test".to_string(),
                body: r#"{}"#.to_string(),
                model: "test".to_string(),
                api_key: "key".to_string(),
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let all_requests = manager.get_batch_requests(batch.id).await.unwrap();
    let real_id = all_requests[0].id();
    // One real id followed by two ids that match nothing in the DB.
    let mixed_ids = vec![
        real_id,
        RequestId(Uuid::new_v4()),
        RequestId(Uuid::new_v4()),
    ];
    let results = manager.get_requests(mixed_ids).await.unwrap();
    assert_eq!(results.len(), 3);
    // Results are positional: index 0 is the real request, 1 and 2 are misses.
    assert!(results[0].is_ok());
    assert!(results[1].is_err());
    assert!(results[2].is_err());
}
/// Per-daemon model concurrency limits must apply independently to each
/// daemon: with a limit of 3 per model, no single daemon may claim more
/// than 3 requests for a model, and the grand total is capped at
/// 3 per model × 2 models × 3 daemons = 18.
#[sqlx::test]
async fn test_per_daemon_limit_allows_independent_claiming(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let config = crate::daemon::DaemonConfig::default();
    // Inserting through a non-`mut` binding — model_concurrency_limits is
    // presumably an interior-mutable map (e.g. DashMap, as used elsewhere
    // in this file); confirm against DaemonConfig's definition.
    config
        .model_concurrency_limits
        .insert("model-a".to_string(), 3);
    config
        .model_concurrency_limits
        .insert("model-b".to_string(), 3);
    let manager = Arc::new(
        PostgresRequestManager::with_client(
            TestDbPools::new(pool.clone()).await.unwrap(),
            http_client,
        )
        .with_config(config),
    );
    // 10 requests per model — more than any daemon is allowed to claim.
    let mut templates = Vec::new();
    for model in &["model-a", "model-b"] {
        for n in 1..=10 {
            templates.push(RequestTemplateInput {
                custom_id: None,
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/test".to_string(),
                body: format!(r#"{{"model":"{}","n":{}}}"#, model, n),
                model: model.to_string(),
                api_key: "key".to_string(),
            });
        }
    }
    let file_id = manager
        .create_file("multi-daemon-test".to_string(), None, templates)
        .await
        .unwrap();
    let _batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    // Three distinct daemons claim with full (3-per-model) capacity each.
    let daemon1_id = DaemonId::from(Uuid::new_v4());
    let daemon2_id = DaemonId::from(Uuid::new_v4());
    let daemon3_id = DaemonId::from(Uuid::new_v4());
    let full_capacity: std::collections::HashMap<String, usize> =
        [("model-a".to_string(), 3), ("model-b".to_string(), 3)].into();
    let claimed1 = manager
        .claim_requests(10, daemon1_id, &full_capacity, &HashMap::new())
        .await
        .unwrap();
    let claimed2 = manager
        .claim_requests(10, daemon2_id, &full_capacity, &HashMap::new())
        .await
        .unwrap();
    let claimed3 = manager
        .claim_requests(10, daemon3_id, &full_capacity, &HashMap::new())
        .await
        .unwrap();
    // Tally claims per (daemon, model).
    let mut per_daemon_model_counts: Vec<std::collections::HashMap<String, i32>> = Vec::new();
    for claimed in [&claimed1, &claimed2, &claimed3] {
        let mut counts = std::collections::HashMap::new();
        for request in claimed {
            *counts.entry(request.data.model.clone()).or_insert(0) += 1;
        }
        per_daemon_model_counts.push(counts);
    }
    // No daemon may exceed its per-model limit of 3.
    for (i, counts) in per_daemon_model_counts.iter().enumerate() {
        for (model, count) in counts {
            assert!(
                *count <= 3,
                "Daemon {} claimed {} requests for {}, exceeding per-daemon limit of 3",
                i + 1,
                count,
                model,
            );
        }
    }
    let total = claimed1.len() + claimed2.len() + claimed3.len();
    assert!(
        total <= 18,
        "Total claimed should not exceed 18 (3 per model × 2 models × 3 daemons), got {}",
        total,
    );
}
/// A first metadata item with no filename gets an auto-generated name; a
/// later metadata item that does carry a filename must replace it, while
/// all templates streamed before and after are kept.
#[sqlx::test]
async fn test_auto_generated_filename_then_real_filename_differs(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let items = vec![
        // Initial metadata: no filename supplied.
        FileStreamItem::Metadata(FileMetadata {
            uploaded_by: Some("user1".to_string()),
            ..Default::default()
        }),
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: Some("test-1".to_string()),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: "{}".to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        }),
        // Second metadata item carries the real filename.
        FileStreamItem::Metadata(FileMetadata {
            filename: Some("updated-filename.jsonl".to_string()),
            uploaded_by: Some("user1".to_string()),
            ..Default::default()
        }),
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: Some("test-2".to_string()),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: "{}".to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        }),
    ];
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream::iter(items))
            .await
            .expect("Should create file successfully"),
    );
    let file = manager.get_file(file_id).await.unwrap();
    assert_eq!(file.name, "updated-filename.jsonl");
    assert_eq!(file.uploaded_by, Some("user1".to_string()));
    let content = manager.get_file_content(file_id).await.unwrap();
    assert_eq!(content.len(), 2);
}
/// When several metadata items appear in one stream, the last one wins for
/// both filename and description.
#[sqlx::test]
async fn test_multiple_metadata_updates_last_wins(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // Three successive metadata updates, then a single template.
    let items = vec![
        FileStreamItem::Metadata(FileMetadata {
            filename: Some("first-name.jsonl".to_string()),
            description: Some("First description".to_string()),
            uploaded_by: Some("user1".to_string()),
            ..Default::default()
        }),
        FileStreamItem::Metadata(FileMetadata {
            filename: Some("second-name.jsonl".to_string()),
            description: Some("Second description".to_string()),
            uploaded_by: Some("user1".to_string()),
            ..Default::default()
        }),
        FileStreamItem::Metadata(FileMetadata {
            filename: Some("final-name.jsonl".to_string()),
            description: Some("Final description".to_string()),
            uploaded_by: Some("user1".to_string()),
            ..Default::default()
        }),
        FileStreamItem::Template(RequestTemplateInput {
            custom_id: None,
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: "{}".to_string(),
            model: "test".to_string(),
            api_key: "key".to_string(),
        }),
    ];
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream::iter(items))
            .await
            .expect("Should create file"),
    );
    // Only the last metadata item's values must survive.
    let file = manager.get_file(file_id).await.unwrap();
    assert_eq!(file.name, "final-name.jsonl");
    assert_eq!(file.description, Some("Final description".to_string()));
}
/// A stream containing only a metadata item (no templates) must still create
/// a valid, empty file carrying that metadata.
#[sqlx::test]
async fn test_empty_file_no_templates_but_with_filename(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // Metadata only — the stream ends without a single template.
    let items = vec![FileStreamItem::Metadata(FileMetadata {
        filename: Some("empty-file.jsonl".to_string()),
        description: Some("A file with no templates".to_string()),
        uploaded_by: Some("user1".to_string()),
        ..Default::default()
    })];
    let file_id = expect_stream_success(
        manager
            .create_file_stream(stream::iter(items))
            .await
            .expect("Should create empty file"),
    );
    let file = manager.get_file(file_id).await.unwrap();
    assert_eq!(file.name, "empty-file.jsonl");
    assert_eq!(
        file.description,
        Some("A file with no templates".to_string())
    );
    // The file exists but holds zero templates.
    let content = manager.get_file_content(file_id).await.unwrap();
    assert_eq!(content.len(), 0);
}
/// Polls `check` until it returns `true` or `timeout` elapses, sleeping
/// 50 ms between attempts. Returns whether the condition was observed
/// within the window.
async fn wait_for<F, Fut>(mut check: F, timeout: std::time::Duration) -> bool
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = bool>,
{
    // Compute the absolute deadline once instead of re-measuring elapsed time.
    let deadline = std::time::Instant::now() + timeout;
    while std::time::Instant::now() < deadline {
        match check().await {
            true => return true,
            false => tokio::time::sleep(std::time::Duration::from_millis(50)).await,
        }
    }
    false
}
/// End-to-end cancellation: a running daemon picks up both requests of a
/// batch; once they reach `processing`, `cancel_batch` must move them to
/// `canceled` (2 canceled, 0 in progress) and the daemon must shut down
/// cleanly on the cancellation token.
#[sqlx::test]
async fn test_batch_cancellation_with_stream(pool: sqlx::PgPool) {
    use crate::http::HttpResponse;
    use std::time::Duration;
    let http_client = Arc::new(MockHttpClient::new());
    let manager = Arc::new(PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client.clone(),
    ));
    let file_id = manager
        .create_file(
            "test_cancellation".to_string(),
            None,
            vec![
                RequestTemplateInput {
                    custom_id: Some("req1".to_string()),
                    endpoint: "http://example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: r#"{"test": 1}"#.to_string(),
                    model: "model-a".to_string(),
                    api_key: "test-key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: Some("req2".to_string()),
                    endpoint: "http://example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: r#"{"test": 2}"#.to_string(),
                    model: "model-a".to_string(),
                    api_key: "test-key".to_string(),
                },
            ],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            created_by: Some("test-user".to_string()),
            completion_window: "24h".to_string(),
            endpoint: "/v1/chat/completions".to_string(),
            metadata: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    http_client.clear_calls();
    // Triggered responses: while the trigger handles are held, the mock HTTP
    // calls stay in flight — keeping both requests parked in `processing`
    // (presumably; confirm against MockHttpClient::add_response_with_trigger).
    let trigger1 = http_client.add_response_with_trigger(
        "POST /test",
        Ok(HttpResponse {
            status: 200,
            body: "ok".to_string(),
        }),
    );
    let _trigger2 = http_client.add_response_with_trigger(
        "POST /test",
        Ok(HttpResponse {
            status: 200,
            body: "ok".to_string(),
        }),
    );
    let shutdown_token = tokio_util::sync::CancellationToken::new();
    let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
    model_concurrency_limits.insert("model-a".to_string(), 5);
    // Aggressive claim interval so the daemon picks up work quickly.
    let config = crate::daemon::DaemonConfig {
        claim_batch_size: 10,
        model_concurrency_limits,
        claim_interval_ms: 10,
        max_retries: Some(10_000),
        stop_before_deadline_ms: Some(900_000),
        backoff_ms: 100,
        backoff_factor: 2,
        max_backoff_ms: 1000,
        status_log_interval_ms: None,
        heartbeat_interval_ms: 1000,
        should_retry: Arc::new(|_| false),
        claim_timeout_ms: 5000,
        processing_timeout_ms: 10000,
        ..Default::default()
    };
    let daemon = Arc::new(crate::daemon::Daemon::new(
        manager.clone(),
        http_client.clone(),
        config,
        shutdown_token.clone(),
    ));
    let daemon_handle = tokio::spawn({
        let daemon = daemon.clone();
        async move { daemon.run().await }
    });
    let requests = manager.get_batch_requests(batch.id).await.unwrap();
    assert_eq!(requests.len(), 2);
    // Wait until the daemon has moved BOTH requests into `processing`.
    let manager_clone = manager.clone();
    let batch_id = batch.id;
    let reached_processing = wait_for(
        || async {
            if let Ok(reqs) = manager_clone.get_batch_requests(batch_id).await {
                reqs.iter().all(|r| matches!(r, AnyRequest::Processing(_)))
            } else {
                false
            }
        },
        Duration::from_secs(3),
    )
    .await;
    assert!(
        reached_processing,
        "Both requests should reach processing state"
    );
    manager.cancel_batch(batch.id).await.unwrap();
    // Cancellation must drain in-flight work: 2 canceled, 0 in progress.
    let manager_clone = manager.clone();
    let batch_shows_canceled = wait_for(
        || async {
            if let Ok(status) = manager_clone.get_batch_status(batch_id).await {
                return status.canceled_requests == 2 && status.in_progress_requests == 0;
            }
            false
        },
        Duration::from_secs(2),
    )
    .await;
    assert!(
        batch_shows_canceled,
        "Batch should show 2 canceled requests and 0 in_progress"
    );
    // Stop the daemon and bound the wait so the test cannot hang.
    shutdown_token.cancel();
    let _ = tokio::time::timeout(Duration::from_secs(5), daemon_handle).await;
    drop(trigger1);
}
/// Virtual output/error files report an *estimated* size while the batch is
/// incomplete and only flip `size_finalized` via `get_file` once every
/// request has reached a terminal state.
#[sqlx::test]
async fn test_virtual_files_lazy_finalized_via_get_file(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "lazy-finalize-get-test".to_string(),
            None,
            vec![
                RequestTemplateInput {
                    custom_id: Some("req-1".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: "{}".to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: Some("req-2".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: "{}".to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: Some("req-3".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: "{}".to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                },
            ],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/test".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let output_file_id = batch.output_file_id.unwrap();
    let error_file_id = batch.error_file_id.unwrap();
    let requests = manager.get_batch_requests(batch.id).await.unwrap();
    // Terminal states for 2 of 3 requests: one completed, one failed.
    // The third stays pending, so the batch remains incomplete.
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'completed',
            response_status = 200,
            response_body = $2,
            response_size = $3,
            completed_at = NOW()
        WHERE id = $1
        "#,
        *requests[0].id() as Uuid,
        r#"{"result":"success"}"#,
        19i64,
    )
    .execute(&pool)
    .await
    .unwrap();
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'failed',
            error = $2,
            response_size = $3,
            failed_at = NOW()
        WHERE id = $1
        "#,
        *requests[1].id() as Uuid,
        r#"{"code":"rate_limit","message":"Too many requests"}"#,
        52i64,
    )
    .execute(&pool)
    .await
    .unwrap();
    // Incomplete batch: sizes are estimates and must not be finalized yet.
    let output_file = manager.get_file(output_file_id).await.unwrap();
    assert!(
        !output_file.size_finalized,
        "Output file should not be finalized (batch incomplete)"
    );
    assert!(
        output_file.size_bytes > 0,
        "Output file should have estimated size > 0"
    );
    let error_file = manager.get_file(error_file_id).await.unwrap();
    assert!(
        !error_file.size_finalized,
        "Error file should not be finalized (batch incomplete)"
    );
    assert!(
        error_file.size_bytes > 0,
        "Error file should have estimated size > 0"
    );
    // Complete the last request — the batch is now terminal.
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'completed',
            response_status = 200,
            response_body = '{"done":true}',
            response_size = 14,
            completed_at = NOW()
        WHERE id = $1
        "#,
        *requests[2].id() as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    // get_file must now lazily finalize both virtual files.
    let output_file_after = manager.get_file(output_file_id).await.unwrap();
    assert!(
        output_file_after.size_finalized,
        "Output file should be finalized after batch complete"
    );
    let error_file_after = manager.get_file(error_file_id).await.unwrap();
    assert!(
        error_file_after.size_finalized,
        "Error file should be finalized after batch complete"
    );
}
/// Listing files must also trigger lazy finalization for completed batches:
/// after `list_files` returns a calculated size, the `size_finalized` flag
/// should be persisted in the DB shortly after.
#[sqlx::test]
async fn test_virtual_files_lazy_finalized_via_list_files(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let file_id = manager
        .create_file(
            "lazy-finalize-list-test".to_string(),
            None,
            vec![RequestTemplateInput {
                custom_id: None,
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/test".to_string(),
                body: "{}".to_string(),
                model: "test".to_string(),
                api_key: "key".to_string(),
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/test".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: Some("user1".to_string()),
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let output_file_id = batch.output_file_id.unwrap();
    // Complete the entire batch directly in the DB.
    sqlx::query!(
        r#"
        UPDATE requests
        SET state = 'completed',
            response_status = 200,
            response_body = '{"ok":true}',
            response_size = 12,
            completed_at = NOW()
        WHERE batch_id = $1
        "#,
        *batch.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    // Listing output files for this user should return a calculated size.
    let files = manager
        .list_files(crate::batch::FileFilter {
            purpose: Some(crate::batch::Purpose::BatchOutput.to_string()),
            uploaded_by: Some("user1".to_string()),
            ..Default::default()
        })
        .await
        .unwrap();
    let output_file = files.iter().find(|f| f.id == output_file_id).unwrap();
    assert!(output_file.size_bytes > 0, "Should have calculated size");
    // The finalized flag is written after list_files returns — presumably by
    // a background task; the 200 ms sleep gives it time (timing-sensitive).
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
    let db_file = sqlx::query!(
        "SELECT size_finalized FROM files WHERE id = $1",
        *output_file_id as Uuid
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert!(db_file.size_finalized);
}
/// Lazy finalization via `list_files` must touch only the page actually
/// returned: after listing page 1 (2 of 3 files) exactly 2 rows are
/// finalized; after listing page 2, all 3 are.
#[sqlx::test]
async fn test_list_files_respects_pagination_only_updates_current_page(pool: sqlx::PgPool) {
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    // Three single-request batches, each fully completed — three output files
    // eligible for finalization.
    let mut output_file_ids = Vec::new();
    for i in 0..3 {
        let file_id = manager
            .create_file(
                format!("batch-{}", i),
                None,
                vec![RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/test".to_string(),
                    body: "{}".to_string(),
                    model: "test".to_string(),
                    api_key: "key".to_string(),
                }],
            )
            .await
            .unwrap();
        let batch = manager
            .create_batch(BatchInput {
                file_id,
                endpoint: "/v1/test".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: Some("user1".to_string()),
                api_key_id: None,
                api_key: None,
                total_requests: None,
            })
            .await
            .unwrap();
        output_file_ids.push(batch.output_file_id.unwrap());
        sqlx::query!(
            r#"
            UPDATE requests
            SET state = 'completed',
                response_status = 200,
                response_body = '{"done":true}',
                response_size = 14,
                completed_at = NOW()
            WHERE batch_id = $1
            "#,
            *batch.id as Uuid,
        )
        .execute(&pool)
        .await
        .unwrap();
    }
    // Page 1: limit 2.
    let page1 = manager
        .list_files(crate::batch::FileFilter {
            purpose: Some(crate::batch::Purpose::BatchOutput.to_string()),
            uploaded_by: Some("user1".to_string()),
            limit: Some(2),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(page1.len(), 2, "First page should have 2 files");
    // Finalization happens after list_files returns — presumably in the
    // background; the sleep gives it time to land (timing-sensitive).
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
    let finalized_count = sqlx::query_scalar!(
        r#"
        SELECT COUNT(*) as "count!"
        FROM files
        WHERE id = ANY($1) AND size_finalized = TRUE
        "#,
        &output_file_ids
            .iter()
            .map(|id| **id as Uuid)
            .collect::<Vec<_>>(),
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert_eq!(finalized_count, 2, "Only page 1 files should be finalized");
    // Page 2: cursor past the last file of page 1.
    let page2 = manager
        .list_files(crate::batch::FileFilter {
            purpose: Some(crate::batch::Purpose::BatchOutput.to_string()),
            uploaded_by: Some("user1".to_string()),
            after: Some(page1.last().unwrap().id),
            limit: Some(2),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(page2.len(), 1, "Second page should have 1 file");
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
    let all_finalized = sqlx::query_scalar!(
        r#"
        SELECT COUNT(*) as "count!"
        FROM files
        WHERE id = ANY($1) AND size_finalized = TRUE
        "#,
        &output_file_ids
            .iter()
            .map(|id| **id as Uuid)
            .collect::<Vec<_>>(),
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert_eq!(all_finalized, 3, "All files should now be finalized");
}
#[sqlx::test]
// While a batch still has unfinished requests, its virtual output/error
// files must report an estimated (non-zero) size and must NOT be marked
// finalized -- even across repeated reads.
async fn test_incomplete_batch_gives_estimate_not_finalized(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
// Three identical templates -> three requests in the batch.
let file_id = manager
.create_file(
"incomplete-batch-test".to_string(),
None,
(0..3)
.map(|_| RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
})
.collect(),
)
.await
.unwrap();
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/test".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let output_file_id = batch.output_file_id.unwrap();
let error_file_id = batch.error_file_id.unwrap();
let requests = manager.get_batch_requests(batch.id).await.unwrap();
// Finish only 2 of the 3 requests (one success, one failure); the third
// stays pending, leaving the batch incomplete.
sqlx::query!(
r#"
UPDATE requests
SET state = 'completed',
response_status = 200,
response_body = '{"partial":true}',
response_size = 17,
completed_at = NOW()
WHERE id = $1
"#,
*requests[0].id() as Uuid,
)
.execute(&pool)
.await
.unwrap();
sqlx::query!(
r#"
UPDATE requests
SET state = 'failed',
error = '{"error":"test"}',
response_size = 16,
failed_at = NOW()
WHERE id = $1
"#,
*requests[1].id() as Uuid,
)
.execute(&pool)
.await
.unwrap();
// Read the files several times: the estimate must stay non-zero and the
// finalized flag must never flip while the batch is incomplete.
for _ in 0..3 {
let output_file = manager.get_file(output_file_id).await.unwrap();
assert!(
!output_file.size_finalized,
"Output file should NOT be finalized (batch incomplete)"
);
assert!(output_file.size_bytes > 0, "Should have non-zero estimate");
let error_file = manager.get_file(error_file_id).await.unwrap();
assert!(
!error_file.size_finalized,
"Error file should NOT be finalized (batch incomplete)"
);
assert!(error_file.size_bytes > 0, "Should have non-zero estimate");
}
// The persisted rows must also remain un-finalized.
let output_db = sqlx::query!(
"SELECT size_finalized FROM files WHERE id = $1",
*output_file_id as Uuid
)
.fetch_one(&pool)
.await
.unwrap();
assert!(!output_db.size_finalized);
let error_db = sqlx::query!(
"SELECT size_finalized FROM files WHERE id = $1",
*error_file_id as Uuid
)
.fetch_one(&pool)
.await
.unwrap();
assert!(!error_db.size_finalized);
}
#[sqlx::test]
// Once a virtual output file's size has been finalized, later reads and
// listings must serve the cached value and never recompute it from the
// underlying request rows.
async fn test_finalized_file_uses_cached_value_no_recomputation(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"cached-value-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/test".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let output_file_id = batch.output_file_id.unwrap();
// Complete every request so the first read can finalize the output file.
sqlx::query!(
r#"
UPDATE requests
SET state = 'completed',
response_status = 200,
response_body = '{"cached":true}',
response_size = 16,
completed_at = NOW()
WHERE batch_id = $1
"#,
*batch.id as Uuid,
)
.execute(&pool)
.await
.unwrap();
let file1 = manager.get_file(output_file_id).await.unwrap();
assert!(file1.size_finalized);
let finalized_size = file1.size_bytes;
// Corrupt the underlying response sizes: a recomputation would now
// produce a wildly different total, so matching the earlier value proves
// the cache is used.
sqlx::query!(
r#"
UPDATE requests
SET response_size = 999999
WHERE batch_id = $1
"#,
*batch.id as Uuid,
)
.execute(&pool)
.await
.unwrap();
let file2 = manager.get_file(output_file_id).await.unwrap();
assert!(file2.size_finalized);
assert_eq!(
file2.size_bytes, finalized_size,
"Should use cached finalized value, not recalculate"
);
// The list path must serve the same cached value as the get path.
let files = manager
.list_files(crate::batch::FileFilter {
purpose: Some(crate::batch::Purpose::BatchOutput.to_string()),
..Default::default()
})
.await
.unwrap();
let listed_file = files.iter().find(|f| f.id == output_file_id).unwrap();
assert_eq!(
listed_file.size_bytes, finalized_size,
"List should use cached value too"
);
}
#[sqlx::test]
async fn test_normal_files_finalized_immediately_no_calculation(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"normal-file-test".to_string(),
Some("A normal input file".to_string()),
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let file = manager.get_file(file_id).await.unwrap();
assert!(
file.size_finalized,
"Normal input files should be finalized immediately"
);
assert_eq!(file.size_bytes, 0, "Input files have size 0 by default");
let db_file = sqlx::query!(
"SELECT size_finalized, purpose FROM files WHERE id = $1",
*file_id as Uuid
)
.fetch_one(&pool)
.await
.unwrap();
assert!(db_file.size_finalized, "Should be finalized in DB");
assert!(
db_file.purpose.is_none() || db_file.purpose.as_deref() == Some("batch"),
"Should not be a virtual output/error file"
);
}
#[sqlx::test]
// Claiming must follow SLA urgency (soonest expires_at first, NULL last),
// not batch creation order: the batches here are created in the reverse of
// the expected claim order.
async fn test_sla_based_claim_priority(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
// Batch 3: created first, expires_at left NULL (no SLA) -> claimed last.
let file3 = manager
.create_file(
"no-sla-batch".to_string(),
None,
vec![RequestTemplateInput {
custom_id: Some("no-sla-1".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"no_sla":true}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch3 = manager
.create_batch(crate::batch::BatchInput {
file_id: file3,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Short pause between creations -- presumably to keep creation timestamps
// strictly ordered so SLA ordering (not insertion ties) is what's tested.
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Batch 2: medium urgency (expires in 2 hours) -> claimed second.
let file2 = manager
.create_file(
"medium-batch".to_string(),
None,
vec![RequestTemplateInput {
custom_id: Some("medium-1".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"medium":true}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch2 = manager
.create_batch(crate::batch::BatchInput {
file_id: file2,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
sqlx::query!(
"UPDATE batches SET expires_at = NOW() + INTERVAL '2 hours' WHERE id = $1",
*batch2.id as Uuid
)
.execute(&pool)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Batch 1: created last but most urgent (expires in 30 minutes) ->
// claimed first.
let file1 = manager
.create_file(
"urgent-batch".to_string(),
None,
vec![RequestTemplateInput {
custom_id: Some("urgent-1".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"urgent":true}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch1 = manager
.create_batch(crate::batch::BatchInput {
file_id: file1,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
sqlx::query!(
"UPDATE batches SET expires_at = NOW() + INTERVAL '30 minutes' WHERE id = $1",
*batch1.id as Uuid
)
.execute(&pool)
.await
.unwrap();
// Claim one request at a time and verify the SLA-driven order.
let daemon_id = DaemonId::from(Uuid::new_v4());
let capacity = HashMap::from([("test".to_string(), 10)]);
let claimed = manager
.claim_requests(1, daemon_id, &capacity, &HashMap::new())
.await
.expect("Failed to claim requests");
assert_eq!(claimed.len(), 1);
assert_eq!(
claimed[0].data.batch_id, batch1.id,
"First claim should be from most urgent batch (30 min expiration)"
);
assert_eq!(claimed[0].data.custom_id, Some("urgent-1".to_string()));
let claimed2 = manager
.claim_requests(1, daemon_id, &capacity, &HashMap::new())
.await
.expect("Failed to claim requests");
assert_eq!(claimed2.len(), 1);
assert_eq!(
claimed2[0].data.batch_id, batch2.id,
"Second claim should be from medium priority batch (2 hour expiration)"
);
assert_eq!(claimed2[0].data.custom_id, Some("medium-1".to_string()));
let claimed3 = manager
.claim_requests(1, daemon_id, &capacity, &HashMap::new())
.await
.expect("Failed to claim requests");
assert_eq!(claimed3.len(), 1);
assert_eq!(
claimed3[0].data.batch_id, batch3.id,
"Third claim should be from no-SLA batch (NULL expiration)"
);
assert_eq!(claimed3[0].data.custom_id, Some("no-sla-1".to_string()));
}
#[sqlx::test]
// A multi-request claim must drain the batch with the earliest deadline
// before touching any other batch, even when many less-urgent batches were
// created first.
async fn test_claim_drains_earliest_batch_first(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
// Nine filler batches, three requests each, all expiring in 6 hours.
let mut filler_ids = Vec::new();
for i in 0..9 {
let file = manager
.create_file(
format!("filler-{i}"),
None,
(0..3)
.map(|j| RequestTemplateInput {
custom_id: Some(format!("filler-{i}-{j}")),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test-fifo".to_string(),
api_key: "key".to_string(),
})
.collect(),
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id: file,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
sqlx::query!(
"UPDATE batches SET expires_at = NOW() + INTERVAL '6 hours' WHERE id = $1",
*batch.id as Uuid
)
.execute(&pool)
.await
.unwrap();
filler_ids.push(batch.id);
}
// One urgent batch (30-minute deadline), created AFTER all the fillers.
let urgent_file = manager
.create_file(
"urgent".to_string(),
None,
(0..3)
.map(|j| RequestTemplateInput {
custom_id: Some(format!("urgent-{j}")),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test-fifo".to_string(),
api_key: "key".to_string(),
})
.collect(),
)
.await
.unwrap();
let urgent_batch = manager
.create_batch(crate::batch::BatchInput {
file_id: urgent_file,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
sqlx::query!(
"UPDATE batches SET expires_at = NOW() + INTERVAL '30 minutes' WHERE id = $1",
*urgent_batch.id as Uuid
)
.execute(&pool)
.await
.unwrap();
// Claim exactly 3: all must come from the urgent batch, fully draining it.
let daemon_id = DaemonId::from(Uuid::new_v4());
let capacity = HashMap::from([("test-fifo".to_string(), 3)]);
let claimed = manager
.claim_requests(3, daemon_id, &capacity, &HashMap::new())
.await
.expect("Failed to claim requests");
assert_eq!(claimed.len(), 3);
for req in &claimed {
assert_eq!(
req.data.batch_id, urgent_batch.id,
"All claimed requests should come from the urgent batch (earliest expires_at), \
but got one from a filler batch — indicates non-FIFO ordering"
);
}
}
#[sqlx::test]
// Fair scheduling across users: with no prior activity everything is
// claimable, but once one user has many active requests, other users'
// requests must be ordered ahead of theirs.
async fn test_per_user_fair_scheduling(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
// One 2-request batch per user (3 users, 6 requests total).
for user in &["user-a", "user-b", "user-c"] {
let file_id = manager
.create_file(
format!("{}-file", user),
None,
vec![
RequestTemplateInput {
custom_id: Some(format!("{}-req-1", user)),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "fair-test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: Some(format!("{}-req-2", user)),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "fair-test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: Some(user.to_string()),
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
}
// Cold start: no user activity counts, so all 6 should be claimable.
let daemon_id = DaemonId::from(Uuid::new_v4());
let capacity = HashMap::from([("fair-test".to_string(), 6)]);
let claimed = manager
.claim_requests(6, daemon_id, &capacity, &HashMap::new())
.await
.expect("Failed to claim requests (cold start)");
assert_eq!(
claimed.len(),
6,
"Should claim all 6 requests on cold start"
);
// Reset every claim so scheduling can be re-run with activity counts.
sqlx::query!("UPDATE requests SET state = 'pending', daemon_id = NULL, claimed_at = NULL WHERE state = 'claimed'")
.execute(&pool)
.await
.unwrap();
// Now user-a is reported as already having 5 active requests.
let user_counts = HashMap::from([("user-a".to_string(), 5usize)]);
let claimed = manager
.claim_requests(6, daemon_id, &capacity, &user_counts)
.await
.expect("Failed to claim requests (with user counts)");
assert_eq!(claimed.len(), 6, "Should still claim all 6 requests");
// Record, per user, the positions of their requests in the claim order.
let mut per_user: HashMap<String, Vec<usize>> = HashMap::new();
for (i, req) in claimed.iter().enumerate() {
per_user
.entry(req.data.created_by.clone())
.or_default()
.push(i);
}
// Defaults are deliberately asymmetric so a missing user fails the
// asserts: user-a defaults to 0 (best), the others to usize::MAX (worst).
let user_a_first = per_user.get("user-a").map(|v| v[0]).unwrap_or(0);
let user_b_first = per_user.get("user-b").map(|v| v[0]).unwrap_or(usize::MAX);
let user_c_first = per_user.get("user-c").map(|v| v[0]).unwrap_or(usize::MAX);
assert!(
user_b_first < user_a_first,
"user-b (0 active) should be prioritised over user-a (5 active), \
but user-b first index={} vs user-a first index={}",
user_b_first,
user_a_first
);
assert!(
user_c_first < user_a_first,
"user-c (0 active) should be prioritised over user-a (5 active), \
but user-c first index={} vs user-a first index={}",
user_c_first,
user_a_first
);
}
#[sqlx::test]
async fn test_per_user_deadline_ordering_preserved(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id_urgent = manager
.create_file(
"urgent-file".to_string(),
None,
vec![RequestTemplateInput {
custom_id: Some("urgent-req".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "deadline-test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let file_id_relaxed = manager
.create_file(
"relaxed-file".to_string(),
None,
vec![RequestTemplateInput {
custom_id: Some("relaxed-req".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "deadline-test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let urgent_batch = manager
.create_batch(BatchInput {
file_id: file_id_urgent,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "1h".to_string(),
metadata: None,
created_by: Some("same-user".to_string()),
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let _relaxed_batch = manager
.create_batch(BatchInput {
file_id: file_id_relaxed,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "7d".to_string(),
metadata: None,
created_by: Some("same-user".to_string()),
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let daemon_id = DaemonId::from(Uuid::new_v4());
let capacity = HashMap::from([("deadline-test".to_string(), 2)]);
let user_counts = HashMap::from([("same-user".to_string(), 0usize)]);
let claimed = manager
.claim_requests(1, daemon_id, &capacity, &user_counts)
.await
.expect("Failed to claim requests");
assert_eq!(claimed.len(), 1);
assert_eq!(
claimed[0].data.batch_id, urgent_batch.id,
"Urgent batch (earlier deadline) should be claimed first when user priority is equal"
);
}
#[sqlx::test]
// urgency_weight blends SLA urgency into per-user fairness: with weight
// 0.5 and equal user activity, the tighter SLA wins; with weight 0.0,
// fairness alone decides even against a much more urgent SLA.
async fn test_urgency_weighted_scheduling(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// Create a one-request file + batch for `user` with the given SLA window.
async fn setup_user_batch(
manager: &PostgresRequestManager<TestDbPools, MockHttpClient>,
user: &str,
completion_window: &str,
) -> BatchId {
let file_id = manager
.create_file(
format!("{}-file", user),
None,
vec![RequestTemplateInput {
custom_id: Some(format!("{}-req", user)),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "urgency-test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: completion_window.to_string(),
metadata: None,
created_by: Some(user.to_string()),
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
batch.id
}
// Manager with urgency_weight = 0.5: SLA matters.
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(DaemonConfig {
urgency_weight: 0.5,
..DaemonConfig::default()
});
let batch_a = setup_user_batch(&manager, "user-a", "1h").await;
let _batch_b = setup_user_batch(&manager, "user-b", "24h").await;
let daemon_id = DaemonId::from(Uuid::new_v4());
let capacity = HashMap::from([("urgency-test".to_string(), 2)]);
// Equal activity for both users, so only urgency can break the tie.
let user_counts = HashMap::from([
("user-a".to_string(), 1usize),
("user-b".to_string(), 1usize),
]);
let claimed = manager
.claim_requests(1, daemon_id, &capacity, &user_counts)
.await
.expect("Failed to claim with urgency weight");
assert_eq!(claimed.len(), 1);
assert_eq!(
claimed[0].data.batch_id, batch_a,
"With urgency_weight=0.5 and equal user activity, \
the 1hr SLA batch should be claimed before the 24hr batch"
);
// Reset the claim so the same requests can be scheduled again.
sqlx::query!(
"UPDATE requests SET state = 'pending', daemon_id = NULL, claimed_at = NULL WHERE state = 'claimed'"
)
.execute(&pool)
.await
.unwrap();
// Second manager with urgency_weight = 0.0: pure fairness.
let manager_no_urgency = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(DaemonConfig {
urgency_weight: 0.0,
..DaemonConfig::default()
});
// user-a is heavily loaded; user-b idle.
let user_counts_skewed = HashMap::from([
("user-a".to_string(), 5usize),
("user-b".to_string(), 0usize),
]);
let claimed = manager_no_urgency
.claim_requests(1, daemon_id, &capacity, &user_counts_skewed)
.await
.expect("Failed to claim without urgency weight");
assert_eq!(claimed.len(), 1);
assert_eq!(
claimed[0].data.created_by, "user-b",
"With urgency_weight=0.0, user-b (0 active) should beat user-a (5 active) \
despite user-a having a more urgent 1hr SLA"
);
}
#[sqlx::test]
// When every request in a batch fails, the batch is complete: the virtual
// output file must finalize at size 0 and the error file must finalize
// with a positive size.
async fn test_empty_virtual_files_finalized_at_zero(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"all-fail-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/test".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let output_file_id = batch.output_file_id.unwrap();
let error_file_id = batch.error_file_id.unwrap();
// Fail every request in the batch, leaving no successful output at all.
sqlx::query!(
r#"
UPDATE requests
SET state = 'failed',
error = '{"error":"all failed"}',
response_size = 22,
failed_at = NOW()
WHERE batch_id = $1
"#,
*batch.id as Uuid,
)
.execute(&pool)
.await
.unwrap();
let output_file = manager.get_file(output_file_id).await.unwrap();
assert!(
output_file.size_finalized,
"Empty output file should be finalized"
);
assert_eq!(
output_file.size_bytes, 0,
"Output file with no completions should have size 0"
);
let error_file = manager.get_file(error_file_id).await.unwrap();
assert!(error_file.size_finalized, "Error file should be finalized");
assert!(
error_file.size_bytes > 0,
"Error file should have size > 0 (2 failed requests)"
);
}
#[sqlx::test]
// Retrying a batch must: reset each failed request to pending with
// retry_attempt = 0 and all claim metadata cleared, clear the batch's
// terminal timestamps, leave already-pending requests alone, and be a
// no-op when nothing is failed.
async fn test_retry_failed_requests_for_batch(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = Arc::new(PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
));
let file_id = manager
.create_file(
"retry-batch-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: Some("req-1".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: Some("req-2".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: Some("req-3".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: "{}".to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Fail the two oldest requests, leaving stale claim metadata (daemon_id,
// claimed_at, started_at) and a non-zero retry_attempt that the retry
// path must wipe.
sqlx::query!(
r#"
UPDATE requests
SET state = 'failed',
error = 'test error',
failed_at = NOW(),
retry_attempt = 3,
daemon_id = '00000000-0000-0000-0000-000000000001',
claimed_at = NOW() - INTERVAL '1 hour',
started_at = NOW() - INTERVAL '30 minutes'
WHERE batch_id = $1
AND id IN (
SELECT id FROM requests WHERE batch_id = $1 ORDER BY created_at LIMIT 2
)
"#,
*batch.id as Uuid,
)
.execute(&pool)
.await
.unwrap();
// Put the batch itself into a terminal, already-notified state.
sqlx::query!(
r#"
UPDATE batches
SET failed_at = NOW(),
finalizing_at = NOW(),
notification_sent_at = NOW()
WHERE id = $1
"#,
*batch.id as Uuid,
)
.execute(&pool)
.await
.unwrap();
// Sanity-check the setup: 2 failed, 1 still pending.
let requests_before = manager.get_batch_requests(batch.id).await.unwrap();
let failed_count = requests_before
.iter()
.filter(|r| matches!(r, AnyRequest::Failed(_)))
.count();
let pending_count = requests_before
.iter()
.filter(|r| matches!(r, AnyRequest::Pending(_)))
.count();
assert_eq!(failed_count, 2);
assert_eq!(pending_count, 1);
let retried = manager
.retry_failed_requests_for_batch(batch.id)
.await
.unwrap();
assert_eq!(retried, 2, "Should have retried 2 failed requests");
let requests_after = manager.get_batch_requests(batch.id).await.unwrap();
let pending_after = requests_after
.iter()
.filter(|r| matches!(r, AnyRequest::Pending(_)))
.count();
assert_eq!(
pending_after, 3,
"All requests should be pending after retry"
);
// The batch's terminal/notification timestamps must all be cleared so it
// can run (and notify) again.
let batch_after = sqlx::query!(
r#"
SELECT completed_at, failed_at, finalizing_at, notification_sent_at
FROM batches WHERE id = $1
"#,
*batch.id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert!(
batch_after.completed_at.is_none(),
"completed_at should be cleared after retry"
);
assert!(
batch_after.failed_at.is_none(),
"failed_at should be cleared after retry"
);
assert!(
batch_after.finalizing_at.is_none(),
"finalizing_at should be cleared after retry"
);
assert!(
batch_after.notification_sent_at.is_none(),
"notification_sent_at should be cleared after retry"
);
for req in &requests_after {
if let AnyRequest::Pending(r) = req {
assert_eq!(
r.state.retry_attempt, 0,
"retry_attempt should be reset to 0"
);
}
}
// All stale claim metadata must be gone too.
let cleared_check = sqlx::query!(
r#"
SELECT COUNT(*) as "count!"
FROM requests
WHERE batch_id = $1
AND (daemon_id IS NOT NULL OR claimed_at IS NOT NULL OR started_at IS NOT NULL)
"#,
*batch.id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(
cleared_check.count, 0,
"daemon_id, claimed_at, and started_at should all be NULL after retry"
);
// With nothing failed any more, a second retry is a no-op.
let retried_again = manager
.retry_failed_requests_for_batch(batch.id)
.await
.unwrap();
assert_eq!(retried_again, 0, "No failed requests to retry");
}
#[sqlx::test]
async fn test_purge_orphaned_templates_after_file_delete(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"purge-templates-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":1}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":2}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let count_before: i64 = sqlx::query_scalar!(
"SELECT count(*) as \"count!\" FROM request_templates WHERE file_id = $1",
*file_id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(count_before, 2);
manager.delete_file(file_id).await.unwrap();
let deleted = manager.purge_orphaned_rows(1000).await.unwrap();
assert!(deleted >= 2, "Should have deleted at least 2 templates");
let count_after: i64 = sqlx::query_scalar!(
"SELECT count(*) as \"count!\" FROM request_templates WHERE file_id = $1",
*file_id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(count_after, 0);
let deleted_again = manager.purge_orphaned_rows(1000).await.unwrap();
assert_eq!(deleted_again, 0);
}
#[sqlx::test]
// Deleting a batch orphans its requests; a purge pass must remove them.
async fn test_purge_orphaned_requests_after_batch_delete(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"purge-requests-test".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":1}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":2}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
},
],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let count_before: i64 = sqlx::query_scalar!(
"SELECT count(*) as \"count!\" FROM requests WHERE batch_id = $1",
*batch.id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(count_before, 2);
// Orphan the requests, then purge them.
manager.delete_batch(batch.id).await.unwrap();
let deleted = manager.purge_orphaned_rows(1000).await.unwrap();
assert!(deleted >= 2, "Should have deleted at least 2 requests");
let count_after: i64 = sqlx::query_scalar!(
"SELECT count(*) as \"count!\" FROM requests WHERE batch_id = $1",
*batch.id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(count_after, 0);
}
#[sqlx::test]
// Purge is strictly for orphans: templates still referenced by a live file
// and requests still referenced by a live batch must be untouched.
async fn test_purge_does_not_delete_active_rows(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
);
let file_id = manager
.create_file(
"purge-active-test".to_string(),
None,
vec![RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/test".to_string(),
body: r#"{"n":1}"#.to_string(),
model: "test".to_string(),
api_key: "key".to_string(),
}],
)
.await
.unwrap();
let batch = manager
.create_batch(crate::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
// Nothing was deleted above, so there are no orphans to purge.
let deleted = manager.purge_orphaned_rows(1000).await.unwrap();
assert_eq!(deleted, 0, "Should not delete any active rows");
let template_count: i64 = sqlx::query_scalar!(
"SELECT count(*) as \"count!\" FROM request_templates WHERE file_id = $1",
*file_id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(template_count, 1);
let request_count: i64 = sqlx::query_scalar!(
"SELECT count(*) as \"count!\" FROM requests WHERE batch_id = $1",
*batch.id as Uuid,
)
.fetch_one(&pool)
.await
.unwrap();
assert_eq!(request_count, 1);
}
#[sqlx::test]
async fn test_purge_returns_zero_when_empty(pool: sqlx::PgPool) {
    // Purging a database with no orphaned rows is a no-op that reports 0.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );
    let purged_rows = manager.purge_orphaned_rows(1000).await.unwrap();
    assert_eq!(purged_rows, 0);
}
#[sqlx::test]
async fn test_purge_respects_batch_size(pool: sqlx::PgPool) {
    // The purge limit caps how many orphaned rows are removed per call:
    // ten orphaned templates purged with limit 3 take four passes (3+3+3+1),
    // after which a fifth pass removes nothing.
    let http_client = Arc::new(MockHttpClient::new());
    let manager = PostgresRequestManager::with_client(
        TestDbPools::new(pool.clone()).await.unwrap(),
        http_client,
    );

    // Ten templates with distinct bodies, all belonging to one file.
    let inputs: Vec<RequestTemplateInput> = (0..10)
        .map(|i| RequestTemplateInput {
            custom_id: None,
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/test".to_string(),
            body: format!(r#"{{"n":{}}}"#, i),
            model: "test".to_string(),
            api_key: "key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file("purge-batch-size-test".to_string(), None, inputs)
        .await
        .unwrap();

    // Orphan every template by deleting the file.
    manager.delete_file(file_id).await.unwrap();

    let first_pass = manager.purge_orphaned_rows(3).await.unwrap();
    assert_eq!(first_pass, 3, "Should delete exactly 3 templates");
    let left_over: i64 = sqlx::query_scalar!(
        "SELECT count(*) as \"count!\" FROM request_templates WHERE file_id = $1",
        *file_id as Uuid,
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert_eq!(left_over, 7);

    let second_pass = manager.purge_orphaned_rows(3).await.unwrap();
    assert_eq!(second_pass, 3);
    let third_pass = manager.purge_orphaned_rows(3).await.unwrap();
    assert_eq!(third_pass, 3);
    let fourth_pass = manager.purge_orphaned_rows(3).await.unwrap();
    assert_eq!(fourth_pass, 1, "Only 1 remaining");
    let fifth_pass = manager.purge_orphaned_rows(3).await.unwrap();
    assert_eq!(fifth_pass, 0, "Nothing left to purge");
}
#[sqlx::test]
async fn test_purge_deletes_templates_after_file_delete_without_waiting_for_batch(
    pool: sqlx::PgPool,
) {
    // Deleting a file must make its templates purgeable immediately, even
    // while a batch created from that file still exists.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "purge-safety-guard-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: r#"{"n":1}"#.to_string(),
                path: "/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    manager.delete_file(file_id).await.unwrap();
    let purged = manager.purge_orphaned_rows(1000).await.unwrap();
    assert!(purged >= 1, "Should delete orphaned templates");
    // The templates are gone...
    let template_count: i64 = sqlx::query_scalar!(
        "SELECT count(*) as \"count!\" FROM request_templates WHERE file_id = $1",
        *file_id as Uuid,
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert_eq!(template_count, 0);
    // ...the batch was flipped into cancelling and detached from the file...
    let batch_after = manager.get_batch(batch.id).await.unwrap();
    assert!(
        batch_after.cancelling_at.is_some(),
        "Batch should be cancelled"
    );
    assert_eq!(
        batch_after.file_id, None,
        "Batch file_id should be NULL after file deletion"
    );
    // ...while its request survives, merely losing its template link.
    let request_count: i64 = sqlx::query_scalar!(
        "SELECT count(*) as \"count!\" FROM requests WHERE batch_id = $1",
        *batch.id as Uuid,
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert_eq!(request_count, 1, "Request should still exist");
    let null_template_count: i64 = sqlx::query_scalar!(
        "SELECT count(*) as \"count!\" FROM requests WHERE batch_id = $1 AND template_id IS NULL",
        *batch.id as Uuid,
    )
    .fetch_one(&pool)
    .await
    .unwrap();
    assert_eq!(
        null_template_count, 1,
        "Request template_id should be NULL after template deletion"
    );
    // And the input file itself is no longer retrievable.
    let input_file = manager.get_file(file_id).await;
    assert!(
        input_file.is_err(),
        "Input file should not be accessible after deletion"
    );
}
#[sqlx::test]
async fn test_purge_batch_size_applies_independently_to_requests_and_templates(
    pool: sqlx::PgPool,
) {
    // The purge limit caps orphaned requests and orphaned templates
    // separately, so one call may delete more than `batch_size` rows total.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let inputs: Vec<RequestTemplateInput> = (0..5)
        .map(|i| RequestTemplateInput {
            api_key: "key".to_string(),
            model: "test".to_string(),
            body: format!(r#"{{"n":{}}}"#, i),
            path: "/test".to_string(),
            method: "POST".to_string(),
            endpoint: "https://api.example.com".to_string(),
            custom_id: None,
        })
        .collect();
    let file_id = manager
        .create_file("purge-independent-limit-test".to_string(), None, inputs)
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    // Orphan both the requests (batch delete) and the templates (file delete).
    manager.delete_batch(batch.id).await.unwrap();
    manager.delete_file(file_id).await.unwrap();
    let deleted = manager.purge_orphaned_rows(3).await.unwrap();
    assert!(
        deleted > 3,
        "batch_size should apply independently: expected >3, got {}",
        deleted
    );
    // Keep purging until the backlog is drained.
    let mut total_deleted = deleted;
    let mut step = manager.purge_orphaned_rows(3).await.unwrap();
    while step != 0 {
        total_deleted += step;
        step = manager.purge_orphaned_rows(3).await.unwrap();
    }
    assert_eq!(
        total_deleted, 10,
        "Should delete all 5 requests + 5 templates"
    );
}
#[sqlx::test]
async fn test_pending_request_counts_by_model_and_window_basic(pool: sqlx::PgPool) {
    use chrono::Duration;
    // Two batches for two models with different deadlines; counts must only
    // include requests whose batch expires inside the queried window.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    // Every template is identical apart from its custom id and model.
    let tmpl = |cid: &str, model: &str| RequestTemplateInput {
        custom_id: Some(cid.to_string()),
        endpoint: "/v1/chat/completions".to_string(),
        method: "POST".to_string(),
        path: "/v1/chat/completions".to_string(),
        body: format!(r#"{{"input":"{}"}}"#, cid),
        model: model.to_string(),
        api_key: "k".to_string(),
    };
    let batch_input = |fid| BatchInput {
        file_id: fid,
        endpoint: "/v1/chat/completions".to_string(),
        completion_window: "24h".to_string(),
        metadata: None,
        created_by: None,
        api_key_id: None,
        api_key: None,
        total_requests: None,
    };
    let file_id_a = manager
        .create_file(
            "file-a".to_string(),
            None,
            vec![tmpl("a1", "model-a"), tmpl("a2", "model-a")],
        )
        .await
        .unwrap();
    let batch_a = manager.create_batch(batch_input(file_id_a)).await.unwrap();
    let file_id_b = manager
        .create_file(
            "file-b".to_string(),
            None,
            vec![tmpl("b1", "model-b"), tmpl("b2", "model-b")],
        )
        .await
        .unwrap();
    let batch_b = manager.create_batch(batch_input(file_id_b)).await.unwrap();
    // batch_a is due in 30 minutes, batch_b in 3 hours.
    let now = Utc::now();
    sqlx::query!(
        "UPDATE batches SET expires_at = $1 WHERE id = $2",
        now + Duration::minutes(30),
        *batch_a.id as Uuid
    )
    .execute(&pool)
    .await
    .unwrap();
    sqlx::query!(
        "UPDATE batches SET expires_at = $1 WHERE id = $2",
        now + Duration::hours(3),
        *batch_b.id as Uuid
    )
    .execute(&pool)
    .await
    .unwrap();
    let windows = vec![("1h".to_string(), 3600), ("4h".to_string(), 14_400)];
    let states = vec!["pending".to_string()];
    let model_filter: Vec<String> = vec![];
    let counts = manager
        .get_pending_request_counts_by_model_and_completion_window(
            &windows,
            &states,
            &model_filter,
            false,
        )
        .await
        .unwrap();
    // model-a (due in 30m) lands in both windows; model-b (due in 3h) only
    // fits the 4h window.
    assert_eq!(*counts.get("model-a").unwrap().get("1h").unwrap(), 2);
    assert_eq!(*counts.get("model-a").unwrap().get("4h").unwrap(), 2);
    assert_eq!(*counts.get("model-b").unwrap().get("1h").unwrap(), 0);
    assert_eq!(*counts.get("model-b").unwrap().get("4h").unwrap(), 2);
}
#[sqlx::test]
async fn test_pending_request_counts_respects_states_models_and_cancelling(pool: sqlx::PgPool) {
    use chrono::Duration;
    // One batch holding a model-a and a model-b request; verify the state
    // filter, the model filter, and the cancelling-batch exclusion.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let tmpl = |cid: &str, model: &str| RequestTemplateInput {
        custom_id: Some(cid.to_string()),
        endpoint: "/v1/chat/completions".to_string(),
        method: "POST".to_string(),
        path: "/v1/chat/completions".to_string(),
        body: format!(r#"{{"input":"{}"}}"#, cid),
        model: model.to_string(),
        api_key: "k".to_string(),
    };
    let file_id = manager
        .create_file(
            "file-mixed".to_string(),
            None,
            vec![tmpl("a1", "model-a"), tmpl("b1", "model-b")],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let now = Utc::now();
    sqlx::query!(
        "UPDATE batches SET expires_at = $1 WHERE id = $2",
        now + Duration::minutes(10),
        *batch.id as Uuid
    )
    .execute(&pool)
    .await
    .unwrap();
    // Move the model-b request out of `pending` so state filtering is visible.
    sqlx::query!(
        r#"
UPDATE requests
SET state = 'claimed',
daemon_id = $1,
claimed_at = NOW()
WHERE batch_id = $2 AND model = 'model-b'
"#,
        Uuid::new_v4(),
        *batch.id as Uuid
    )
    .execute(&pool)
    .await
    .unwrap();
    let windows = vec![("15m".to_string(), 900)];
    let states = vec!["pending".to_string()];
    let model_filter = vec!["model-a".to_string()];
    // Only the still-pending model-a request is visible under this filter.
    let counts = manager
        .get_pending_request_counts_by_model_and_completion_window(
            &windows,
            &states,
            &model_filter,
            false,
        )
        .await
        .unwrap();
    assert_eq!(*counts.get("model-a").unwrap().get("15m").unwrap(), 1);
    assert!(counts.get("model-b").is_none());
    // Widen the state filter and drop the model filter: both models appear.
    let states = vec!["pending".to_string(), "claimed".to_string()];
    let model_filter: Vec<String> = vec![];
    let counts_all = manager
        .get_pending_request_counts_by_model_and_completion_window(
            &windows,
            &states,
            &model_filter,
            false,
        )
        .await
        .unwrap();
    assert_eq!(*counts_all.get("model-a").unwrap().get("15m").unwrap(), 1);
    assert_eq!(*counts_all.get("model-b").unwrap().get("15m").unwrap(), 1);
    // Once the batch is cancelling, its requests stop being counted at all.
    sqlx::query!(
        "UPDATE batches SET cancelling_at = NOW() WHERE id = $1",
        *batch.id as Uuid
    )
    .execute(&pool)
    .await
    .unwrap();
    let counts_cancelled = manager
        .get_pending_request_counts_by_model_and_completion_window(
            &windows,
            &states,
            &model_filter,
            false,
        )
        .await
        .unwrap();
    assert!(counts_cancelled.is_empty());
}
#[sqlx::test]
async fn test_pending_request_counts_empty_inputs(pool: sqlx::PgPool) {
    // No windows, or no states, must short-circuit to an empty result map.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool).await.unwrap(), client);
    let no_windows = manager
        .get_pending_request_counts_by_model_and_completion_window(
            &[],
            &["pending".to_string()],
            &[],
            false,
        )
        .await
        .unwrap();
    assert!(no_windows.is_empty());
    let no_states = manager
        .get_pending_request_counts_by_model_and_completion_window(
            &[("1h".to_string(), 3600)],
            &[],
            &[],
            false,
        )
        .await
        .unwrap();
    assert!(no_states.is_empty());
}
#[sqlx::test]
async fn test_list_batches_filter_by_api_key_id(pool: sqlx::PgPool) {
    // api_key_ids: Some([...]) returns only matching batches; Some([])
    // matches nothing; None disables the filter entirely.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "api-key-filter-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let key_a = uuid::Uuid::new_v4();
    let key_b = uuid::Uuid::new_v4();
    // One batch per key, plus one with no key at all.
    let input_for = |key: Option<uuid::Uuid>| crate::batch::BatchInput {
        file_id,
        endpoint: "/v1/chat/completions".to_string(),
        completion_window: "24h".to_string(),
        metadata: None,
        created_by: None,
        api_key_id: key,
        api_key: None,
        total_requests: None,
    };
    manager.create_batch(input_for(Some(key_a))).await.unwrap();
    manager.create_batch(input_for(Some(key_b))).await.unwrap();
    manager.create_batch(input_for(None)).await.unwrap();
    // Single key: exactly one match.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            api_key_ids: Some(vec![key_a]),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].api_key_id, Some(key_a));
    // Both keys: both keyed batches, but not the keyless one.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            api_key_ids: Some(vec![key_a, key_b]),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(results.len(), 2);
    let returned_keys: std::collections::HashSet<_> =
        results.iter().filter_map(|b| b.api_key_id).collect();
    assert!(returned_keys.contains(&key_a));
    assert!(returned_keys.contains(&key_b));
    // Empty list: nothing can match.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            api_key_ids: Some(vec![]),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(results.len(), 0);
    // No filter: all three come back.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(results.len(), 3);
}
#[sqlx::test]
async fn test_list_batches_filter_by_status_completed(pool: sqlx::PgPool) {
    // A freshly created batch is in progress: absent from the "completed"
    // listing, present in the "in_progress" one.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "status-filter-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let completed = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("completed".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(completed.iter().all(|b| b.id != batch.id));
    let in_progress = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("in_progress".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(in_progress.iter().any(|b| b.id == batch.id));
}
#[sqlx::test]
async fn test_list_batches_filter_unknown_status_returns_error(pool: sqlx::PgPool) {
    // An unrecognized status string must be rejected, even when combined
    // with other filters that would otherwise yield an empty result set.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let outcome = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("nonexistent_status".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("Unknown batch status filter"),
        "Expected error about unknown status, got: {}",
        message
    );
    // Status validation must win even when api_key_ids is Some(vec![]).
    let outcome = manager
        .list_batches(crate::batch::ListBatchesFilter {
            api_key_ids: Some(vec![]),
            status: Some("not_a_status".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("Unknown batch status filter"),
        "Expected status error even with empty api_key_ids, got: {}",
        message
    );
}
#[sqlx::test]
async fn test_list_batches_filter_by_time_range(pool: sqlx::PgPool) {
    // created_after / created_before must bracket the batch's creation time.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "time-filter-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    let created = batch.created_at;
    let one_sec = chrono::Duration::seconds(1);
    // Window starting after creation: no hits.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            created_after: Some(created + one_sec),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(results.is_empty());
    // Window ending before creation: no hits.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            created_before: Some(created - one_sec),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(results.is_empty());
    // Window straddling creation: hit.
    let results = manager
        .list_batches(crate::batch::ListBatchesFilter {
            created_after: Some(created - one_sec),
            created_before: Some(created + one_sec),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(!results.is_empty());
}
#[sqlx::test]
async fn test_list_batches_search_by_batch_id(pool: sqlx::PgPool) {
    // A prefix of the batch's UUID should find it; a random string should not.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "id-search-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    // Search by the first 8 hex characters of the UUID.
    let id_text = batch.id.0.to_string();
    let prefix = &id_text[..8];
    let found = manager
        .list_batches(crate::batch::ListBatchesFilter {
            search: Some(prefix.to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(found.iter().any(|b| b.id == batch.id));
    let not_found = manager
        .list_batches(crate::batch::ListBatchesFilter {
            search: Some("zzz_no_match_zzz".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(not_found.iter().all(|b| b.id != batch.id));
}
#[sqlx::test]
async fn test_list_batches_filter_by_status_in_progress_includes_validating(
    pool: sqlx::PgPool,
) {
    // A batch whose total_requests is 0 (still validating) must be returned
    // by the "in_progress" filter.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "validating-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    // Force the batch into a "validating" shape.
    sqlx::query!(
        "UPDATE batches SET total_requests = 0 WHERE id = $1",
        *batch.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let in_progress = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("in_progress".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(in_progress.iter().any(|b| b.id == batch.id));
}
#[sqlx::test]
async fn test_list_batches_filter_by_status_cancelled(pool: sqlx::PgPool) {
    // Both fully-cancelled and still-cancelling batches count as "cancelled",
    // and neither appears as "in_progress".
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "cancelled-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch_input = || crate::batch::BatchInput {
        file_id,
        endpoint: "/v1/chat/completions".to_string(),
        completion_window: "24h".to_string(),
        metadata: None,
        created_by: None,
        api_key_id: None,
        api_key: None,
        total_requests: None,
    };
    let batch_a = manager.create_batch(batch_input()).await.unwrap();
    let batch_b = manager.create_batch(batch_input()).await.unwrap();
    // batch_a goes through the cancel API; batch_b only has cancelling_at set.
    manager.cancel_batch(batch_a.id).await.unwrap();
    sqlx::query!(
        "UPDATE batches SET cancelling_at = NOW() WHERE id = $1",
        *batch_b.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let cancelled = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("cancelled".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(
        cancelled.iter().any(|b| b.id == batch_a.id),
        "fully cancelled batch should match"
    );
    assert!(
        cancelled.iter().any(|b| b.id == batch_b.id),
        "cancelling batch should also match"
    );
    let in_progress = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("in_progress".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(in_progress.iter().all(|b| b.id != batch_a.id));
    assert!(in_progress.iter().all(|b| b.id != batch_b.id));
}
#[sqlx::test]
async fn test_list_batches_filter_by_status_failed(pool: sqlx::PgPool) {
    // Setting failed_at moves a batch from "in_progress" to "failed".
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "failed-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch = manager
        .create_batch(crate::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .unwrap();
    sqlx::query!(
        "UPDATE batches SET failed_at = NOW() WHERE id = $1",
        *batch.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let failed = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("failed".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(failed.iter().any(|b| b.id == batch.id));
    let in_progress = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("in_progress".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(in_progress.iter().all(|b| b.id != batch.id));
}
#[sqlx::test]
async fn test_list_batches_filter_by_status_expired(pool: sqlx::PgPool) {
    // "expired" covers batches past their deadline — still running, or
    // completed only after expiry — but not batches that finished on time.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "expired-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    let batch_input = || crate::batch::BatchInput {
        file_id,
        endpoint: "/v1/chat/completions".to_string(),
        completion_window: "24h".to_string(),
        metadata: None,
        created_by: None,
        api_key_id: None,
        api_key: None,
        total_requests: None,
    };
    // batch_a: deadline an hour ago, never finished.
    let batch_a = manager.create_batch(batch_input()).await.unwrap();
    sqlx::query!(
        "UPDATE batches SET expires_at = NOW() - INTERVAL '1 hour' WHERE id = $1",
        *batch_a.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    // batch_b: completed, but an hour after its deadline.
    let batch_b = manager.create_batch(batch_input()).await.unwrap();
    sqlx::query!(
        "UPDATE batches SET expires_at = NOW() - INTERVAL '2 hours', completed_at = NOW() - INTERVAL '1 hour' WHERE id = $1",
        *batch_b.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    // batch_c: completed comfortably before its deadline.
    let batch_c = manager.create_batch(batch_input()).await.unwrap();
    sqlx::query!(
        "UPDATE batches SET expires_at = NOW() + INTERVAL '1 hour', completed_at = NOW() WHERE id = $1",
        *batch_c.id as Uuid,
    )
    .execute(&pool)
    .await
    .unwrap();
    let expired = manager
        .list_batches(crate::batch::ListBatchesFilter {
            status: Some("expired".to_string()),
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(
        expired.iter().any(|b| b.id == batch_a.id),
        "overdue in-progress batch should match"
    );
    assert!(
        expired.iter().any(|b| b.id == batch_b.id),
        "completed-after-deadline batch should match"
    );
    assert!(
        expired.iter().all(|b| b.id != batch_c.id),
        "on-time completed batch should not match"
    );
}
#[sqlx::test]
async fn test_list_files_filter_by_api_key_id(pool: sqlx::PgPool) {
    use crate::batch::{FileMetadata, FileStreamItem};
    use futures::stream;
    // Files created with an api_key_id must be filterable by that key.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let key_a = uuid::Uuid::new_v4();
    let key_b = uuid::Uuid::new_v4();
    // Build a two-item upload stream (metadata + one template) for a key.
    let upload = |name: &str, key: uuid::Uuid| {
        stream::iter(vec![
            FileStreamItem::Metadata(FileMetadata {
                filename: Some(name.to_string()),
                api_key_id: Some(key),
                ..Default::default()
            }),
            FileStreamItem::Template(RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }),
        ])
    };
    let file_a = manager
        .create_file_stream(upload("file-a.jsonl", key_a))
        .await
        .map(expect_stream_success)
        .unwrap();
    let file_b = manager
        .create_file_stream(upload("file-b.jsonl", key_b))
        .await
        .map(expect_stream_success)
        .unwrap();
    // Single-key filters return exactly the matching file.
    let only_a = manager
        .list_files(crate::batch::FileFilter {
            api_key_ids: Some(vec![key_a]),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(only_a.len(), 1);
    assert_eq!(only_a[0].id, file_a);
    let only_b = manager
        .list_files(crate::batch::FileFilter {
            api_key_ids: Some(vec![key_b]),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(only_b.len(), 1);
    assert_eq!(only_b[0].id, file_b);
    // Both keys: both files.
    let both = manager
        .list_files(crate::batch::FileFilter {
            api_key_ids: Some(vec![key_a, key_b]),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(both.len(), 2);
    let ids: std::collections::HashSet<_> = both.iter().map(|f| f.id).collect();
    assert!(ids.contains(&file_a));
    assert!(ids.contains(&file_b));
    // Some(vec![]) matches nothing.
    let none = manager
        .list_files(crate::batch::FileFilter {
            api_key_ids: Some(vec![]),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(none.len(), 0);
    // No filter: both files are present among all results.
    let all = manager
        .list_files(crate::batch::FileFilter::default())
        .await
        .unwrap();
    let ids: Vec<_> = all.iter().map(|f| f.id).collect();
    assert!(ids.contains(&file_a));
    assert!(ids.contains(&file_b));
}
#[sqlx::test]
async fn test_list_batches_active_first_sorting(pool: sqlx::PgPool) {
    // With active_first, non-terminal batches sort ahead of terminal ones,
    // newest-first within each group.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "active-first-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    // Five batches with strictly increasing created_at timestamps.
    let mut batch_ids = Vec::new();
    for i in 0..5 {
        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
                api_key_id: None,
                api_key: None,
                total_requests: None,
            })
            .await
            .unwrap();
        sqlx::query("UPDATE batches SET created_at = NOW() + ($1 || ' seconds')::INTERVAL WHERE id = $2")
            .bind(i.to_string())
            .bind(*batch.id as Uuid)
            .execute(&pool)
            .await
            .unwrap();
        batch_ids.push(batch.id);
    }
    // Batches 2, 3, 4 become terminal: completed, cancelled, and cancelling.
    sqlx::query("UPDATE batches SET completed_at = NOW() WHERE id = $1")
        .bind(*batch_ids[2] as Uuid)
        .execute(&pool)
        .await
        .unwrap();
    sqlx::query("UPDATE batches SET cancelling_at = NOW(), cancelled_at = NOW() WHERE id = $1")
        .bind(*batch_ids[3] as Uuid)
        .execute(&pool)
        .await
        .unwrap();
    sqlx::query("UPDATE batches SET cancelling_at = NOW() WHERE id = $1")
        .bind(*batch_ids[4] as Uuid)
        .execute(&pool)
        .await
        .unwrap();
    // Plain ordering: newest first regardless of status.
    let plain = manager
        .list_batches(crate::batch::ListBatchesFilter {
            active_first: false,
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    let plain_ids: Vec<_> = plain.iter().map(|b| b.id).collect();
    assert_eq!(
        plain_ids,
        vec![
            batch_ids[4],
            batch_ids[3],
            batch_ids[2],
            batch_ids[1],
            batch_ids[0]
        ]
    );
    // Active-first ordering: the two still-active batches (1, 0) lead.
    let grouped = manager
        .list_batches(crate::batch::ListBatchesFilter {
            active_first: true,
            limit: Some(100),
            ..Default::default()
        })
        .await
        .unwrap();
    let grouped_ids: Vec<_> = grouped.iter().map(|b| b.id).collect();
    assert_eq!(
        grouped_ids,
        vec![
            batch_ids[1],
            batch_ids[0],
            batch_ids[4],
            batch_ids[3],
            batch_ids[2]
        ],
        "Active batches should sort before newer terminal ones"
    );
}
#[sqlx::test]
async fn test_list_batches_active_first_cursor_pagination(pool: sqlx::PgPool) {
    // Cursor pagination must keep working across the active/terminal
    // boundary when active_first ordering is enabled.
    let client = Arc::new(MockHttpClient::new());
    let manager =
        PostgresRequestManager::with_client(TestDbPools::new(pool.clone()).await.unwrap(), client);
    let file_id = manager
        .create_file(
            "cursor-test".to_string(),
            None,
            vec![RequestTemplateInput {
                api_key: "key".to_string(),
                model: "test".to_string(),
                body: "{}".to_string(),
                path: "/v1/test".to_string(),
                method: "POST".to_string(),
                endpoint: "https://api.example.com".to_string(),
                custom_id: None,
            }],
        )
        .await
        .unwrap();
    // Four batches with strictly increasing created_at timestamps.
    let mut batch_ids = Vec::new();
    for i in 0..4 {
        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
                api_key_id: None,
                api_key: None,
                total_requests: None,
            })
            .await
            .unwrap();
        sqlx::query("UPDATE batches SET created_at = NOW() + ($1 || ' seconds')::INTERVAL WHERE id = $2")
            .bind(i.to_string())
            .bind(*batch.id as Uuid)
            .execute(&pool)
            .await
            .unwrap();
        batch_ids.push(batch.id);
    }
    // The last two batches become terminal (completed).
    for idx in [2, 3] {
        sqlx::query("UPDATE batches SET completed_at = NOW() WHERE id = $1")
            .bind(*batch_ids[idx] as Uuid)
            .execute(&pool)
            .await
            .unwrap();
    }
    // Page 1: the two active batches, newest first.
    let page1 = manager
        .list_batches(crate::batch::ListBatchesFilter {
            active_first: true,
            limit: Some(2),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(page1.len(), 2);
    assert_eq!(page1[0].id, batch_ids[1], "newest active batch first");
    assert_eq!(page1[1].id, batch_ids[0], "oldest active batch second");
    // Page 2: resume after the last active batch; terminal batches follow.
    let page2 = manager
        .list_batches(crate::batch::ListBatchesFilter {
            active_first: true,
            limit: Some(2),
            after: Some(page1.last().unwrap().id),
            ..Default::default()
        })
        .await
        .unwrap();
    assert_eq!(page2.len(), 2);
    assert_eq!(page2[0].id, batch_ids[3], "newest terminal batch first");
    assert_eq!(page2[1].id, batch_ids[2], "oldest terminal batch second");
    // Page 3: exhausted.
    let page3 = manager
        .list_batches(crate::batch::ListBatchesFilter {
            active_first: true,
            limit: Some(2),
            after: Some(page2.last().unwrap().id),
            ..Default::default()
        })
        .await
        .unwrap();
    assert!(
        page3.is_empty(),
        "Should have no more results after last page"
    );
    let all_ids: Vec<_> = page1.iter().chain(page2.iter()).map(|b| b.id).collect();
    assert_eq!(
        all_ids,
        vec![batch_ids[1], batch_ids[0], batch_ids[3], batch_ids[2]],
        "Full pagination should return all batches in active-first order"
    );
}
}