use fusillade::TestDbPools;
use fusillade::batch::{BatchInput, RequestTemplateInput};
use fusillade::daemon::{DaemonConfig, ModelEscalationConfig, default_should_retry};
use fusillade::http::{HttpResponse, MockHttpClient};
use fusillade::manager::postgres::PostgresRequestManager;
use fusillade::manager::{DaemonExecutor, Storage};
use std::sync::Arc;
use std::time::Duration;
#[sqlx::test]
#[test_log::test]
async fn test_daemon_claims_and_completes_request(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"success"}"#.to_string(),
}),
);
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("test-model".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10, model_concurrency_limits,
max_retries: Some(3),
stop_before_deadline_ms: None,
backoff_ms: 100,
backoff_factor: 2,
max_backoff_ms: 1000,
status_log_interval_ms: None, heartbeat_interval_ms: 10000, should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-file".to_string(),
Some("Test file".to_string()),
vec![fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "test-model".to_string(),
api_key: "test-key".to_string(),
}],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get batch requests");
assert_eq!(requests.len(), 1);
let request_id = requests[0].id();
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut completed = false;
while start.elapsed() < timeout {
let results = manager
.get_requests(vec![request_id])
.await
.expect("Failed to get request");
if let Some(Ok(any_request)) = results.first()
&& any_request.is_terminal()
{
if let fusillade::AnyRequest::Completed(req) = any_request {
assert_eq!(req.state.response_status, 200);
assert_eq!(req.state.response_body, r#"{"result":"success"}"#);
completed = true;
break;
} else {
panic!(
"Request reached terminal state but was not completed: {:?}",
any_request
);
}
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
assert!(
completed,
"Request did not complete within timeout. Check daemon processing."
);
assert_eq!(http_client.call_count(), 1);
let calls = http_client.get_calls();
assert_eq!(calls[0].method, "POST");
assert_eq!(calls[0].path, "/v1/test");
assert_eq!(calls[0].api_key, "test-key");
}
#[sqlx::test]
async fn test_daemon_respects_per_model_concurrency_limits(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let trigger1 = http_client.add_response_with_trigger(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"1"}"#.to_string(),
}),
);
let trigger2 = http_client.add_response_with_trigger(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"2"}"#.to_string(),
}),
);
let trigger3 = http_client.add_response_with_trigger(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"3"}"#.to_string(),
}),
);
let trigger4 = http_client.add_response_with_trigger(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"4"}"#.to_string(),
}),
);
let trigger5 = http_client.add_response_with_trigger(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"5"}"#.to_string(),
}),
);
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("gpt-4".to_string(), 2);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
max_retries: Some(3),
stop_before_deadline_ms: None,
backoff_ms: 100,
backoff_factor: 2,
max_backoff_ms: 1000,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100, ..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-file".to_string(),
Some("Test concurrency limits".to_string()),
vec![
fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test1"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test2"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test3"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test4"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test5"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(2);
let mut reached_limit = false;
while start.elapsed() < timeout {
let in_flight = http_client.in_flight_count();
if in_flight == 2 {
reached_limit = true;
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
assert!(
reached_limit,
"Expected exactly 2 requests in-flight, got {}",
http_client.in_flight_count()
);
let start = tokio::time::Instant::now();
let stable_duration = Duration::from_millis(100);
while start.elapsed() < stable_duration {
let in_flight = http_client.in_flight_count();
assert!(
in_flight <= 2,
"Concurrency limit violated: {} requests in-flight (expected max 2)",
in_flight
);
tokio::time::sleep(Duration::from_millis(10)).await;
}
assert_eq!(
http_client.in_flight_count(),
2,
"Expected exactly 2 requests in-flight after stability check"
);
trigger1.send(()).unwrap();
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(2);
let mut third_started = false;
while start.elapsed() < timeout {
if http_client.call_count() >= 3 {
third_started = true;
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
assert!(
third_started,
"Third request should have started after first completed"
);
assert_eq!(
http_client.in_flight_count(),
2,
"Should maintain concurrency limit of 2"
);
trigger2.send(()).unwrap();
trigger3.send(()).unwrap();
trigger4.send(()).unwrap();
trigger5.send(()).unwrap();
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut all_completed = false;
while start.elapsed() < timeout {
let status = manager
.get_batch_status(batch.id)
.await
.expect("Failed to get batch status");
if status.completed_requests == 5 {
all_completed = true;
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
assert!(all_completed, "All 5 requests should have completed");
assert_eq!(http_client.call_count(), 5);
}
#[sqlx::test]
async fn test_daemon_retries_failed_requests(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 500,
body: r#"{"error":"internal error"}"#.to_string(),
}),
);
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 503,
body: r#"{"error":"service unavailable"}"#.to_string(),
}),
);
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"success after retries"}"#.to_string(),
}),
);
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("test-model".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
max_retries: Some(5),
stop_before_deadline_ms: None,
backoff_ms: 10, backoff_factor: 2,
max_backoff_ms: 100,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-file".to_string(),
Some("Test retry logic".to_string()),
vec![fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "test-model".to_string(),
api_key: "test-key".to_string(),
}],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get batch requests");
assert_eq!(requests.len(), 1);
let request_id = requests[0].id();
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut completed = false;
while start.elapsed() < timeout {
let results = manager
.get_requests(vec![request_id])
.await
.expect("Failed to get request");
if let Some(Ok(any_request)) = results.first()
&& let fusillade::AnyRequest::Completed(req) = any_request
{
assert_eq!(req.state.response_status, 200);
assert_eq!(
req.state.response_body,
r#"{"result":"success after retries"}"#
);
completed = true;
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
assert!(completed, "Request should have completed after retries");
assert_eq!(
http_client.call_count(),
3,
"Expected 3 HTTP calls (2 failed attempts + 1 success)"
);
}
/// Verifies that updating a model's concurrency limit in the shared DashMap
/// takes effect while the daemon is running — no restart required.
#[sqlx::test]
async fn test_daemon_dynamically_updates_concurrency_limits(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
// Ten trigger-gated responses so the test controls when requests complete.
let mut triggers = vec![];
for i in 1..=10 {
let trigger = http_client.add_response_with_trigger(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: format!(r#"{{"result":"{}"}}"#, i),
}),
);
triggers.push(trigger);
}
// Start with a limit of 2; the map handle is kept so it can be mutated later.
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("gpt-4".to_string(), 2);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits: model_concurrency_limits.clone(),
max_retries: Some(3),
stop_before_deadline_ms: None,
backoff_ms: 100,
backoff_factor: 2,
max_backoff_ms: 1000,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
// Ten identical gpt-4 templates so all requests contend for the same limit.
let templates: Vec<_> = (1..=10)
.map(|i| fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: format!(r#"{{"prompt":"test{}"}}"#, i),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
})
.collect();
let file_id = manager
.create_file(
"test-file".to_string(),
Some("Test dynamic limits".to_string()),
templates,
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
// Phase 1: with limit = 2, exactly 2 requests should be in flight.
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(2);
let mut reached_initial_limit = false;
while start.elapsed() < timeout {
let in_flight = http_client.in_flight_count();
if in_flight == 2 {
reached_initial_limit = true;
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
assert!(
reached_initial_limit,
"Expected exactly 2 requests in-flight with initial limit"
);
// Phase 2: raise the limit to 5 mid-run and release one request; the
// daemon should now claim additional requests (>= 4 in flight).
model_concurrency_limits.insert("gpt-4".to_string(), 5);
triggers.remove(0).send(()).unwrap();
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(2);
let mut reached_new_limit = false;
while start.elapsed() < timeout {
let in_flight = http_client.in_flight_count();
if in_flight >= 4 {
reached_new_limit = true;
break;
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
assert!(
reached_new_limit,
"Expected more requests in-flight after limit increase, got {}",
http_client.in_flight_count()
);
// Phase 3: lower the limit, release all remaining triggers, and wait for
// every request to complete.
model_concurrency_limits.insert("gpt-4".to_string(), 3);
for trigger in triggers {
trigger.send(()).unwrap();
}
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut all_completed = false;
while start.elapsed() < timeout {
let status = manager
.get_batch_status(batch.id)
.await
.expect("Failed to get batch status");
if status.completed_requests == 10 {
all_completed = true;
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
assert!(all_completed, "All 10 requests should have completed");
assert_eq!(http_client.call_count(), 10);
}
#[sqlx::test]
async fn test_deadline_aware_retry_stops_before_deadline(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
for _ in 0..20 {
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 500,
body: r#"{"error":"server error"}"#.to_string(),
}),
);
}
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("test-model".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
max_retries: Some(10_000),
stop_before_deadline_ms: Some(500), backoff_ms: 50,
backoff_factor: 2,
max_backoff_ms: 200,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-file".to_string(),
Some("Test deadline cutoff".to_string()),
vec![fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "test-model".to_string(),
api_key: "test-key".to_string(),
}],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "2s".to_string(), metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get batch requests");
let request_id = requests[0].id();
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut results = None;
while start.elapsed() < timeout {
let res = manager
.get_requests(vec![request_id])
.await
.expect("Failed to get request");
if let Some(Ok(req)) = res.first()
&& req.is_terminal()
{
results = Some(res);
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
let results = results.expect("Request should have reached terminal state within timeout");
if let Some(Ok(fusillade::AnyRequest::Failed(failed))) = results.first() {
let retry_count = failed.state.retry_attempt;
let call_count = http_client.call_count();
assert!(
(7..=9).contains(&retry_count),
"Expected 7-9 retry attempts based on deadline and backoff calculation, got {}",
retry_count
);
assert_eq!(
call_count,
(retry_count + 1) as usize,
"Expected call count to match retry attempts + 1 initial attempt, got {} calls for {} retry attempts",
call_count,
retry_count
);
assert!(
!failed.state.reason.to_error_message().is_empty(),
"Expected failed request to have failure reason"
);
} else {
panic!(
"Expected request to be in Failed state, got {:?}",
results.first()
);
}
}
#[sqlx::test]
async fn test_retry_stops_at_deadline_when_no_limits_set(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
for _ in 0..20 {
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 500,
body: r#"{"error":"server error"}"#.to_string(),
}),
);
}
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("test-model".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
max_retries: None, stop_before_deadline_ms: None, backoff_ms: 50,
backoff_factor: 2,
max_backoff_ms: 200,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-file".to_string(),
Some("Test no limits retry".to_string()),
vec![fusillade::RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "test-model".to_string(),
api_key: "test-key".to_string(),
}],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "2s".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get batch requests");
let request_id = requests[0].id();
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut results = None;
while start.elapsed() < timeout {
let res = manager
.get_requests(vec![request_id])
.await
.expect("Failed to get request");
if let Some(Ok(req)) = res.first()
&& req.is_terminal()
{
results = Some(res);
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
let results = results.expect("Request should have reached terminal state within timeout");
if let Some(Ok(fusillade::AnyRequest::Failed(failed))) = results.first() {
let retry_count = failed.state.retry_attempt;
let call_count = http_client.call_count();
assert!(
(9..12).contains(&retry_count),
"Expected 9-12 retry attempts (should retry until deadline with no buffer), got {}",
retry_count
);
assert_eq!(
call_count,
(retry_count + 1) as usize,
"Expected call count to match retry attempts + 1 initial attempt, got {} calls for {} retry attempts",
call_count,
retry_count
);
assert!(
!failed.state.reason.to_error_message().is_empty(),
"Expected failed request to have failure reason"
);
} else {
panic!(
"Expected request to be in Failed state, got {:?}",
results.first()
);
}
}
/// Verifies route-at-claim-time escalation: with a 1h completion window and
/// a 7200s (2h) escalation threshold, the remaining time is already below
/// the threshold at claim time, so the request is routed to the escalation
/// model — while keeping the request's original API key.
#[sqlx::test]
async fn test_route_at_claim_time_escalation(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"escalated response"}"#.to_string(),
}),
);
// Escalate gpt-4 -> gpt-4-turbo when less than 2h remains before deadline.
let model_escalations = Arc::new(dashmap::DashMap::new());
model_escalations.insert(
"gpt-4".to_string(),
ModelEscalationConfig {
escalation_model: "gpt-4-turbo".to_string(),
escalation_threshold_seconds: 7200,
},
);
// Both the original and the escalation model need concurrency slots.
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("gpt-4".to_string(), 10);
model_concurrency_limits.insert("gpt-4-turbo".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
model_escalations,
max_retries: Some(3),
stop_before_deadline_ms: None,
backoff_ms: 100,
backoff_factor: 2,
max_backoff_ms: 1000,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-escalation".to_string(),
Some("Test route-at-claim-time escalation".to_string()),
vec![fusillade::RequestTemplateInput {
custom_id: Some("escalation-test".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "original-key".to_string(),
}],
)
.await
.expect("Failed to create file");
// 1h window < 2h threshold, so escalation should kick in immediately.
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "1h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get batch requests");
assert_eq!(requests.len(), 1);
let request_id = requests[0].id();
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
// Poll until the request completes (or the 5s timeout elapses).
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut completed = false;
while start.elapsed() < timeout {
let results = manager
.get_requests(vec![request_id])
.await
.expect("Failed to get request");
if let Some(Ok(any_request)) = results.first()
&& any_request.is_terminal()
{
if let fusillade::AnyRequest::Completed(req) = any_request {
assert_eq!(req.state.response_status, 200);
assert_eq!(
req.state.response_body,
r#"{"result":"escalated response"}"#
);
completed = true;
break;
} else {
panic!(
"Request reached terminal state but was not completed: {:?}",
any_request
);
}
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
assert!(completed, "Request did not complete within timeout");
assert_eq!(http_client.call_count(), 1);
let calls = http_client.get_calls();
assert_eq!(
calls[0].api_key, "original-key",
"Should use original API key (escalation only changes model routing)"
);
}
#[sqlx::test]
async fn test_route_at_claim_time_no_escalation_when_enough_time(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"normal response"}"#.to_string(),
}),
);
let model_escalations = Arc::new(dashmap::DashMap::new());
model_escalations.insert(
"gpt-4".to_string(),
ModelEscalationConfig {
escalation_model: "gpt-4-turbo".to_string(),
escalation_threshold_seconds: 60, },
);
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("gpt-4".to_string(), 10);
model_concurrency_limits.insert("gpt-4-turbo".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
model_escalations,
max_retries: Some(3),
stop_before_deadline_ms: None,
backoff_ms: 100,
backoff_factor: 2,
max_backoff_ms: 1000,
status_log_interval_ms: None,
heartbeat_interval_ms: 10000,
should_retry: Arc::new(default_should_retry),
claim_timeout_ms: 60000,
processing_timeout_ms: 600000,
cancellation_poll_interval_ms: 100,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
let file_id = manager
.create_file(
"test-no-escalation".to_string(),
Some("Test no escalation when enough time".to_string()),
vec![fusillade::RequestTemplateInput {
custom_id: Some("no-escalation-test".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "original-key".to_string(),
}],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(fusillade::batch::BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get batch requests");
let request_id = requests[0].id();
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
let start = tokio::time::Instant::now();
let timeout = Duration::from_secs(5);
let mut completed = false;
while start.elapsed() < timeout {
let results = manager
.get_requests(vec![request_id])
.await
.expect("Failed to get request");
if let Some(Ok(any_request)) = results.first()
&& any_request.is_terminal()
{
completed = true;
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
assert!(completed, "Request did not complete within timeout");
assert_eq!(http_client.call_count(), 1);
let calls = http_client.get_calls();
assert_eq!(
calls[0].api_key, "original-key",
"Should use original API key when time remaining is above threshold"
);
}
mod batch_results_stream {
use super::*;
use futures::StreamExt;
/// Drains the batch-results stream for `batch_id`, returning the
/// successfully yielded items and silently dropping any stream errors.
async fn collect_batch_results(
    manager: &PostgresRequestManager<TestDbPools, MockHttpClient>,
    batch_id: fusillade::batch::BatchId,
) -> Vec<fusillade::batch::BatchResultItem> {
    manager
        .get_batch_results_stream(batch_id, 0, None, None)
        .filter_map(|item| async move { item.ok() })
        .collect::<Vec<_>>()
        .await
}
/// Happy-path results streaming: two completed requests should yield two
/// `Completed` result items carrying their custom_ids and response bodies.
#[sqlx::test]
async fn test_batch_results_basic(pool: sqlx::PgPool) {
// Two canned 200 responses, one per request template below.
let http_client = Arc::new(MockHttpClient::new());
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"success1"}"#.to_string(),
}),
);
http_client.add_response(
"POST /v1/test",
Ok(HttpResponse {
status: 200,
body: r#"{"result":"success2"}"#.to_string(),
}),
);
let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
model_concurrency_limits.insert("gpt-4".to_string(), 10);
let config = DaemonConfig {
claim_batch_size: 10,
claim_interval_ms: 10,
model_concurrency_limits,
..Default::default()
};
let manager = Arc::new(
PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client.clone(),
)
.with_config(config),
);
// Two templates with distinct custom_ids so results can be matched back.
let file_id = manager
.create_file(
"test-results".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: Some("req-1".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test1"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
RequestTemplateInput {
custom_id: Some("req-2".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test2"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
},
],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
let shutdown_token = tokio_util::sync::CancellationToken::new();
manager
.clone()
.run(shutdown_token.clone())
.expect("Failed to start daemon");
// Wait (up to 5s) for both requests to reach a terminal state.
let start = tokio::time::Instant::now();
while start.elapsed() < Duration::from_secs(5) {
let requests = manager
.get_batch_requests(batch.id)
.await
.expect("Failed to get requests");
if requests.iter().all(|r| r.is_terminal()) {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
shutdown_token.cancel();
// Stream the results and verify both custom_ids completed with a body.
let results = collect_batch_results(&manager, batch.id).await;
assert_eq!(results.len(), 2, "Should have 2 results");
let custom_ids: Vec<_> = results
.iter()
.filter_map(|r| r.custom_id.as_ref())
.collect();
assert!(custom_ids.contains(&&"req-1".to_string()));
assert!(custom_ids.contains(&&"req-2".to_string()));
for result in &results {
assert_eq!(
result.status,
fusillade::batch::BatchResultStatus::Completed
);
assert!(result.response_body.is_some());
}
}
/// Verifies the results stream surfaces an error (rather than panicking or
/// yielding nothing) when the batch's backing file reference has been
/// cleared, simulating a deleted input file.
#[sqlx::test]
async fn test_batch_results_deleted_file_returns_error(pool: sqlx::PgPool) {
// No daemon is started here; only the stream's error path is exercised.
let http_client = Arc::new(MockHttpClient::new());
let manager = Arc::new(PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
));
let file_id = manager
.create_file(
"to-delete".to_string(),
None,
vec![RequestTemplateInput {
custom_id: Some("test".to_string()),
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: r#"{"prompt":"test"}"#.to_string(),
model: "gpt-4".to_string(),
api_key: "test-key".to_string(),
}],
)
.await
.expect("Failed to create file");
let batch = manager
.create_batch(BatchInput {
file_id,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.expect("Failed to create batch");
// Simulate file deletion by nulling the batch's file_id directly in SQL.
sqlx::query!(
"UPDATE batches SET file_id = NULL WHERE id = $1",
*batch.id as uuid::Uuid
)
.execute(&pool)
.await
.expect("Failed to clear file_id");
// The stream should yield exactly one item: an error mentioning file_id.
let stream = manager.get_batch_results_stream(batch.id, 0, None, None);
let results: Vec<_> = stream.collect().await;
assert_eq!(results.len(), 1, "Should have one result (the error)");
assert!(
results[0].is_err(),
"Result should be an error when file_id is NULL"
);
let err = results[0].as_ref().unwrap_err();
assert!(
err.to_string().contains("file_id"),
"Error should mention file_id: {}",
err
);
}
// Verifies that a batch's output file remains readable after the batch itself
// is deleted — i.e. deleting a batch does not cascade away its output file.
#[sqlx::test]
async fn test_output_file_streamable_after_batch_deleted(pool: sqlx::PgPool) {
    // Mock upstream: the single request in this test succeeds with a 200.
    let http_client = Arc::new(MockHttpClient::new());
    http_client.add_response(
        "POST /v1/test",
        Ok(HttpResponse {
            status: 200,
            body: r#"{"result":"success"}"#.to_string(),
        }),
    );
    // Fast claim loop and a concurrency limit for the one model in use.
    let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
    model_concurrency_limits.insert("gpt-4".to_string(), 10);
    let config = DaemonConfig {
        claim_batch_size: 10,
        claim_interval_ms: 10,
        model_concurrency_limits,
        ..Default::default()
    };
    let manager = Arc::new(
        PostgresRequestManager::with_client(
            TestDbPools::new(pool.clone()).await.unwrap(),
            http_client.clone(),
        )
        .with_config(config),
    );
    // One-request input file plus its batch.
    let file_id = manager
        .create_file(
            "output-after-delete".to_string(),
            None,
            vec![RequestTemplateInput {
                custom_id: Some("test".to_string()),
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/v1/test".to_string(),
                body: r#"{"prompt":"test"}"#.to_string(),
                model: "gpt-4".to_string(),
                api_key: "test-key".to_string(),
            }],
        )
        .await
        .expect("Failed to create file");
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .expect("Failed to create batch");
    // Start the daemon and poll (up to 5s) until the request completes.
    let shutdown_token = tokio_util::sync::CancellationToken::new();
    manager
        .clone()
        .run(shutdown_token.clone())
        .expect("Failed to start daemon");
    let start = tokio::time::Instant::now();
    while start.elapsed() < Duration::from_secs(5) {
        let b = manager.get_batch(batch.id).await.expect("get_batch failed");
        if b.completed_requests == 1 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    shutdown_token.cancel();
    // Completion should have attached an output file to the batch.
    let batch = manager.get_batch(batch.id).await.expect("get_batch failed");
    let output_file_id = batch
        .output_file_id
        .expect("Batch should have output_file_id");
    let results_before = manager
        .get_file_content(output_file_id)
        .await
        .expect("Should be able to get output file content before deletion");
    assert!(
        !results_before.is_empty(),
        "Output file should have content before deletion"
    );
    // Delete the batch; lookups for the batch itself must now fail...
    manager
        .delete_batch(batch.id)
        .await
        .expect("delete_batch failed");
    let batch_result = manager.get_batch(batch.id).await;
    assert!(
        batch_result.is_err(),
        "Batch should not be found after deletion"
    );
    // ...but the output file's content must still be retrievable.
    let results_after = manager
        .get_file_content(output_file_id)
        .await
        .expect("Should still be able to get output file content after batch deletion");
    assert!(
        !results_after.is_empty(),
        "Output file should still have content after batch deletion"
    );
}
#[sqlx::test]
async fn test_batch_results_pagination(pool: sqlx::PgPool) {
    // Queue five successful mock responses, one per request in the batch.
    let client = Arc::new(MockHttpClient::new());
    for i in 0..5 {
        client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: format!(r#"{{"result":"success{}"}}"#, i),
            }),
        );
    }

    // Manager with a fast-polling daemon config for the single model used.
    let limits = Arc::new(dashmap::DashMap::new());
    limits.insert("gpt-4".to_string(), 10);
    let manager = Arc::new(
        PostgresRequestManager::with_client(
            TestDbPools::new(pool.clone()).await.unwrap(),
            client.clone(),
        )
        .with_config(DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10,
            model_concurrency_limits: limits,
            ..Default::default()
        }),
    );

    // Five templates with distinct custom ids so offsets are observable.
    let templates: Vec<_> = (0..5)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("req-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/test".to_string(),
            body: format!(r#"{{"prompt":"test{}"}}"#, i),
            model: "gpt-4".to_string(),
            api_key: "test-key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file("pagination-test".to_string(), None, templates)
        .await
        .expect("Failed to create file");
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .expect("Failed to create batch");

    // Run the daemon until every request is terminal (capped at 5s).
    let shutdown = tokio_util::sync::CancellationToken::new();
    manager
        .clone()
        .run(shutdown.clone())
        .expect("Failed to start daemon");
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    while tokio::time::Instant::now() < deadline {
        let all_terminal = manager
            .get_batch_requests(batch.id)
            .await
            .expect("Failed to get requests")
            .iter()
            .all(|r| r.is_terminal());
        if all_terminal {
            break;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    shutdown.cancel();

    // Without an offset, all five results come back.
    let all_results = collect_batch_results(&manager, batch.id).await;
    assert_eq!(all_results.len(), 5, "Should have 5 results");

    // Offset 2 skips the first two results and yields the remaining three.
    let offset_results: Vec<_> = manager
        .get_batch_results_stream(batch.id, 2, None, None)
        .filter_map(|r| async { r.ok() })
        .collect()
        .await;
    assert_eq!(
        offset_results.len(),
        3,
        "Should have 3 results with offset 2"
    );
    let offset_ids: Vec<_> = offset_results
        .iter()
        .filter_map(|r| r.custom_id.as_ref())
        .collect();
    assert!(offset_ids.contains(&&"req-2".to_string()));
    assert!(offset_ids.contains(&&"req-3".to_string()));
    assert!(offset_ids.contains(&&"req-4".to_string()));
}
// Verifies that the batch results stream can be filtered by terminal status
// ("completed" vs "failed").
#[sqlx::test]
async fn test_batch_results_status_filter(pool: sqlx::PgPool) {
    // Queue three mock responses — two 200s and one 500. The assertions below
    // rely on the mock serving queued responses in order, so exactly one of
    // the three requests receives the 500.
    let http_client = Arc::new(MockHttpClient::new());
    http_client.add_response(
        "POST /v1/test",
        Ok(HttpResponse {
            status: 200,
            body: r#"{"result":"success1"}"#.to_string(),
        }),
    );
    http_client.add_response(
        "POST /v1/test",
        Ok(HttpResponse {
            status: 200,
            body: r#"{"result":"success2"}"#.to_string(),
        }),
    );
    http_client.add_response(
        "POST /v1/test",
        Ok(HttpResponse {
            status: 500,
            body: r#"{"error":"server error"}"#.to_string(),
        }),
    );
    let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
    model_concurrency_limits.insert("gpt-4".to_string(), 10);
    // max_retries = 0: the request that hits the 500 fails permanently on its
    // first attempt instead of being retried.
    let config = DaemonConfig {
        claim_batch_size: 10,
        claim_interval_ms: 10,
        model_concurrency_limits,
        max_retries: Some(0),
        ..Default::default()
    };
    let manager = Arc::new(
        PostgresRequestManager::with_client(
            TestDbPools::new(pool.clone()).await.unwrap(),
            http_client.clone(),
        )
        .with_config(config),
    );
    // Three identical-model templates with distinct custom ids.
    let templates: Vec<_> = (0..3)
        .map(|i| RequestTemplateInput {
            custom_id: Some(format!("req-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/test".to_string(),
            body: format!(r#"{{"prompt":"test{}"}}"#, i),
            model: "gpt-4".to_string(),
            api_key: "test-key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file("filter-test".to_string(), None, templates)
        .await
        .expect("Failed to create file");
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .expect("Failed to create batch");
    // Run the daemon and poll (up to 5s) until every request is terminal.
    let shutdown_token = tokio_util::sync::CancellationToken::new();
    manager
        .clone()
        .run(shutdown_token.clone())
        .expect("Failed to start daemon");
    let start = tokio::time::Instant::now();
    while start.elapsed() < Duration::from_secs(5) {
        let requests = manager
            .get_batch_requests(batch.id)
            .await
            .expect("Failed to get requests");
        if requests.iter().all(|r| r.is_terminal()) {
            break;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    shutdown_token.cancel();
    // Filtering by "completed" yields only the two successful results.
    let stream =
        manager.get_batch_results_stream(batch.id, 0, None, Some("completed".to_string()));
    let completed_results: Vec<_> = stream.filter_map(|r| async { r.ok() }).collect().await;
    assert_eq!(
        completed_results.len(),
        2,
        "Should have 2 completed results"
    );
    for r in &completed_results {
        assert_eq!(r.status, fusillade::batch::BatchResultStatus::Completed);
    }
    // Filtering by "failed" yields only the request that got the 500.
    let stream =
        manager.get_batch_results_stream(batch.id, 0, None, Some("failed".to_string()));
    let failed_results: Vec<_> = stream.filter_map(|r| async { r.ok() }).collect().await;
    assert_eq!(failed_results.len(), 1, "Should have 1 failed result");
    assert_eq!(
        failed_results[0].status,
        fusillade::batch::BatchResultStatus::Failed
    );
}
// Verifies that the results stream's search parameter matches custom_id as a
// case-insensitive substring (demonstrated by the assertions below).
#[sqlx::test]
async fn test_batch_results_search_filter(pool: sqlx::PgPool) {
    // Three successful mock responses, one per request.
    let http_client = Arc::new(MockHttpClient::new());
    for _ in 0..3 {
        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"success"}"#.to_string(),
            }),
        );
    }
    let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
    model_concurrency_limits.insert("gpt-4".to_string(), 10);
    let config = DaemonConfig {
        claim_batch_size: 10,
        claim_interval_ms: 10,
        model_concurrency_limits,
        ..Default::default()
    };
    let manager = Arc::new(
        PostgresRequestManager::with_client(
            TestDbPools::new(pool.clone()).await.unwrap(),
            http_client.clone(),
        )
        .with_config(config),
    );
    // Custom ids chosen so "request" matches two of them and "ALPHA" one.
    let file_id = manager
        .create_file(
            "search-test".to_string(),
            None,
            vec![
                RequestTemplateInput {
                    custom_id: Some("Alpha-Request".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"test"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "test-key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: Some("Beta-Request".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"test"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "test-key".to_string(),
                },
                RequestTemplateInput {
                    custom_id: Some("Gamma-Item".to_string()),
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"test"}"#.to_string(),
                    model: "gpt-4".to_string(),
                    api_key: "test-key".to_string(),
                },
            ],
        )
        .await
        .expect("Failed to create file");
    let batch = manager
        .create_batch(BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "24h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .expect("Failed to create batch");
    // Run the daemon and poll (up to 5s) until every request is terminal.
    let shutdown_token = tokio_util::sync::CancellationToken::new();
    manager
        .clone()
        .run(shutdown_token.clone())
        .expect("Failed to start daemon");
    let start = tokio::time::Instant::now();
    while start.elapsed() < Duration::from_secs(5) {
        let requests = manager
            .get_batch_requests(batch.id)
            .await
            .expect("Failed to get requests");
        if requests.iter().all(|r| r.is_terminal()) {
            break;
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
    }
    shutdown_token.cancel();
    // Lowercase "request" matches "Alpha-Request" and "Beta-Request" but not
    // "Gamma-Item".
    let stream =
        manager.get_batch_results_stream(batch.id, 0, Some("request".to_string()), None);
    let search_results: Vec<_> = stream.filter_map(|r| async { r.ok() }).collect().await;
    assert_eq!(
        search_results.len(),
        2,
        "Should find 2 results containing 'request'"
    );
    let custom_ids: Vec<_> = search_results
        .iter()
        .filter_map(|r| r.custom_id.as_ref())
        .collect();
    assert!(custom_ids.contains(&&"Alpha-Request".to_string()));
    assert!(custom_ids.contains(&&"Beta-Request".to_string()));
    // Uppercase "ALPHA" still matches "Alpha-Request" — case-insensitive.
    let stream = manager.get_batch_results_stream(batch.id, 0, Some("ALPHA".to_string()), None);
    let alpha_results: Vec<_> = stream.filter_map(|r| async { r.ok() }).collect().await;
    assert_eq!(alpha_results.len(), 1, "Should find 1 result for 'ALPHA'");
    assert_eq!(
        alpha_results[0].custom_id,
        Some("Alpha-Request".to_string())
    );
}
// Verifies that retry_failed_requests_for_batch re-queues every failed request
// in a batch and that the daemon then processes the retried requests through
// to completion.
#[sqlx::test]
#[test_log::test]
async fn test_retry_failed_requests_for_batch_retries_all(pool: sqlx::PgPool) {
    // Queue responses in order: the first pass sees two 429s and two 400s
    // (all four requests fail); the retry pass then sees four 200s.
    let http_client = Arc::new(MockHttpClient::new());
    for _ in 0..2 {
        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 429,
                body: "rate limited".to_string(),
            }),
        );
    }
    for _ in 0..2 {
        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 400,
                body: "bad request".to_string(),
            }),
        );
    }
    for _ in 0..4 {
        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"success"}"#.to_string(),
            }),
        );
    }
    let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
    model_concurrency_limits.insert("test-model".to_string(), 10);
    // max_retries = 0 so the daemon performs no automatic retries; failures
    // stay failed until retry_failed_requests_for_batch is invoked explicitly.
    let config = DaemonConfig {
        claim_batch_size: 10,
        claim_interval_ms: 10,
        model_concurrency_limits,
        max_retries: Some(0),
        stop_before_deadline_ms: None,
        backoff_ms: 100,
        backoff_factor: 2,
        max_backoff_ms: 1000,
        status_log_interval_ms: None,
        heartbeat_interval_ms: 10000,
        should_retry: Arc::new(default_should_retry),
        claim_timeout_ms: 60000,
        processing_timeout_ms: 600000,
        cancellation_poll_interval_ms: 100,
        ..Default::default()
    };
    let manager = Arc::new(
        PostgresRequestManager::with_client(
            TestDbPools::new(pool.clone()).await.unwrap(),
            http_client.clone(),
        )
        .with_config(config),
    );
    // Four templates sharing one model so a single concurrency limit applies.
    let templates = (0..4)
        .map(|i| fusillade::RequestTemplateInput {
            custom_id: Some(format!("req-{}", i)),
            endpoint: "https://api.example.com".to_string(),
            method: "POST".to_string(),
            path: "/v1/test".to_string(),
            body: r#"{"test":"data"}"#.to_string(),
            model: "test-model".to_string(),
            api_key: "test-key".to_string(),
        })
        .collect();
    let file_id = manager
        .create_file("test-file".to_string(), None, templates)
        .await
        .expect("Failed to create file");
    let batch = manager
        .create_batch(fusillade::batch::BatchInput {
            file_id,
            endpoint: "/v1/chat/completions".to_string(),
            completion_window: "1h".to_string(),
            metadata: None,
            created_by: None,
            api_key_id: None,
            api_key: None,
            total_requests: None,
        })
        .await
        .expect("Failed to create batch");
    let shutdown_token = tokio_util::sync::CancellationToken::new();
    manager
        .clone()
        .run(shutdown_token.clone())
        .expect("Failed to start daemon");
    // Wait (up to 5s) for the first pass to fail all four requests.
    let start = tokio::time::Instant::now();
    while start.elapsed() < Duration::from_secs(5) {
        let status = manager
            .get_batch_status(batch.id)
            .await
            .expect("Failed to get status");
        if status.failed_requests == 4 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    let status_before = manager
        .get_batch_status(batch.id)
        .await
        .expect("Failed to get status");
    assert_eq!(status_before.failed_requests, 4);
    // Explicitly retry; all four failed requests should be re-queued.
    let retried_count = manager
        .retry_failed_requests_for_batch(batch.id)
        .await
        .expect("Failed to retry batch");
    assert_eq!(retried_count, 4, "Should retry all 4 failed requests");
    // Wait (up to 5s) for the retry pass to consume the queued 200s.
    let start = tokio::time::Instant::now();
    while start.elapsed() < Duration::from_secs(5) {
        let status = manager
            .get_batch_status(batch.id)
            .await
            .expect("Failed to get status");
        if status.completed_requests == 4 {
            break;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
    shutdown_token.cancel();
    let status_after = manager
        .get_batch_status(batch.id)
        .await
        .expect("Failed to get status");
    assert_eq!(
        status_after.completed_requests, 4,
        "All 4 requests should complete after retry"
    );
    assert_eq!(status_after.failed_requests, 0, "No failed requests remain");
}
}
mod queue_counts {
use super::*;
use fusillade::request::DaemonId;
use uuid::Uuid;
#[sqlx::test]
async fn test_pending_queue_counts_by_model_and_completion_window(pool: sqlx::PgPool) {
let http_client = Arc::new(MockHttpClient::new());
let manager = Arc::new(PostgresRequestManager::with_client(
TestDbPools::new(pool.clone()).await.unwrap(),
http_client,
));
let file_24h = manager
.create_file(
"file-24h".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: "{}".to_string(),
model: "gpt-4".to_string(),
api_key: "k".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: "{}".to_string(),
model: "gpt-4".to_string(),
api_key: "k".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: "{}".to_string(),
model: "gpt-3.5".to_string(),
api_key: "k".to_string(),
},
],
)
.await
.unwrap();
let batch_24h = manager
.create_batch(BatchInput {
file_id: file_24h,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "24h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let file_1h = manager
.create_file(
"file-1h".to_string(),
None,
vec![
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: "{}".to_string(),
model: "gpt-4".to_string(),
api_key: "k".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: "{}".to_string(),
model: "gpt-3.5".to_string(),
api_key: "k".to_string(),
},
RequestTemplateInput {
custom_id: None,
endpoint: "https://api.example.com".to_string(),
method: "POST".to_string(),
path: "/v1/test".to_string(),
body: "{}".to_string(),
model: "gpt-3.5".to_string(),
api_key: "k".to_string(),
},
],
)
.await
.unwrap();
let _batch_1h = manager
.create_batch(BatchInput {
file_id: file_1h,
endpoint: "/v1/chat/completions".to_string(),
completion_window: "1h".to_string(),
metadata: None,
created_by: None,
api_key_id: None,
api_key: None,
total_requests: None,
})
.await
.unwrap();
let reqs_24h = manager.get_batch_requests(batch_24h.id).await.unwrap();
let to_claim = reqs_24h
.iter()
.find(|r| r.data().model == "gpt-4")
.expect("Expected a gpt-4 request")
.id();
let daemon_id = DaemonId(Uuid::new_v4());
sqlx::query!(
"UPDATE requests SET state = 'claimed', daemon_id = $2, claimed_at = NOW() WHERE id = $1",
*to_claim as Uuid,
*daemon_id as Uuid,
)
.execute(&pool)
.await
.unwrap();
let counts = manager
.get_pending_request_counts_by_model_and_completion_window(
&[("1h".to_string(), 3600), ("24h".to_string(), 24 * 3600)],
&["pending".to_string()],
&[],
false,
)
.await
.unwrap();
let mut expected: std::collections::HashMap<
String,
std::collections::HashMap<String, i64>,
> = std::collections::HashMap::new();
expected
.entry("gpt-3.5".to_string())
.or_default()
.insert("1h".to_string(), 2);
expected
.entry("gpt-3.5".to_string())
.or_default()
.insert("24h".to_string(), 3);
expected
.entry("gpt-4".to_string())
.or_default()
.insert("1h".to_string(), 1);
expected
.entry("gpt-4".to_string())
.or_default()
.insert("24h".to_string(), 2);
assert_eq!(counts, expected);
}
}