1use std::collections::HashMap;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
5use std::time::Duration;
6
7use metrics::{counter, gauge, histogram};
8use tokio::sync::{RwLock, Semaphore};
9use tokio::task::JoinSet;
10
11use crate::FusilladeError;
12use crate::batch::BatchId;
13use crate::error::Result;
14use crate::http::{HttpClient, HttpResponse};
15use crate::manager::{DaemonStorage, Storage};
16use crate::request::{DaemonId, RequestCompletionResult};
17
18pub mod transitions;
19pub mod types;
20
21pub use types::{
22 AnyDaemonRecord, DaemonData, DaemonRecord, DaemonState, DaemonStats, DaemonStatus, Dead,
23 Initializing, Running,
24};
25
/// Predicate deciding whether an HTTP response warrants a retry.
/// Shared (`Arc`) so it can be cloned into each spawned processing task;
/// see [`default_should_retry`] for the default policy.
pub type ShouldRetryFn = Arc<dyn Fn(&HttpResponse) -> bool + Send + Sync>;

/// A model's semaphore paired with the concurrency limit it was last sized to
/// (used by `get_semaphore` to detect and apply runtime limit changes).
type SemaphoreEntry = (Arc<Semaphore>, usize);
33
34pub fn default_should_retry(response: &HttpResponse) -> bool {
36 response.status >= 500 || response.status == 429 || response.status == 408
37}
38
39fn default_should_retry_fn() -> ShouldRetryFn {
41 Arc::new(default_should_retry)
42}
43
44fn default_model_escalations() -> Arc<dashmap::DashMap<String, ModelEscalationConfig>> {
46 Arc::new(dashmap::DashMap::new())
47}
48
/// Serde `default` hook for `ModelEscalationConfig::escalation_threshold_seconds`:
/// 15 minutes, expressed in seconds.
fn default_escalation_threshold_seconds() -> i64 {
    15 * 60
}
55
/// Fallback routing for a single model: when a claimed request's batch is
/// close to expiry, the daemon rewrites the request to use
/// `escalation_model` instead (see the escalation pass in `Daemon::run`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelEscalationConfig {
    /// Model name substituted into the request (and its JSON body's `"model"`
    /// field) when escalation triggers.
    pub escalation_model: String,

    /// Escalate when the time remaining before the batch expires drops below
    /// this many seconds. Defaults to 900 (15 minutes) when absent.
    #[serde(default = "default_escalation_threshold_seconds")]
    pub escalation_threshold_seconds: i64,
}
73
/// Tuning knobs for a [`Daemon`]. Serializable so a snapshot can be persisted
/// with the daemon's registration record (`config_snapshot` in `Daemon::run`).
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct DaemonConfig {
    /// Maximum number of requests claimed from storage per polling cycle.
    pub claim_batch_size: usize,

    /// Concurrency limit applied to any model without an explicit entry in
    /// `model_concurrency_limits`.
    pub default_model_concurrency: usize,

    /// Per-model concurrency overrides. Consulted on every permit acquisition,
    /// so entries may be changed at runtime and the semaphores are resized.
    pub model_concurrency_limits: Arc<dashmap::DashMap<String, usize>>,

    /// Per-model escalation routing (see [`ModelEscalationConfig`]).
    /// Not serialized; deserialization yields an empty table.
    #[serde(skip, default = "default_model_escalations")]
    pub model_escalations: Arc<dashmap::DashMap<String, ModelEscalationConfig>>,

    /// Sleep between claim cycles in the main loop, in milliseconds.
    pub claim_interval_ms: u64,

    /// Retry budget fed into the retry-config conversion used by `can_retry`;
    /// presumably `None` means unlimited retries — confirm in `crate::request`.
    pub max_retries: Option<u32>,

    /// Not referenced in this file; presumably a cutoff relative to the batch
    /// deadline used by the retry/claim machinery — TODO confirm.
    pub stop_before_deadline_ms: Option<i64>,

    /// Base retry backoff in milliseconds (consumed via the retry config;
    /// exact schedule lives in `crate::request` — confirm there).
    pub backoff_ms: u64,

    /// Multiplier applied to the backoff between attempts (see note above).
    pub backoff_factor: u64,

    /// Upper bound on the computed backoff, in milliseconds.
    pub max_backoff_ms: u64,

    /// Per-request timeout passed to `request.process`, in milliseconds.
    pub timeout_ms: u64,

    /// Interval for the debug "Daemon status" log task; `None` disables it.
    pub status_log_interval_ms: Option<u64>,

    /// Interval at which the heartbeat task persists [`DaemonStats`].
    pub heartbeat_interval_ms: u64,

    /// Retry predicate applied to HTTP responses during request completion.
    /// Not serializable; deserialization restores [`default_should_retry`].
    #[serde(skip, default = "default_should_retry_fn")]
    pub should_retry: ShouldRetryFn,

    /// Not referenced in this file; presumably how long a claim may sit
    /// unprocessed before storage reaps it — TODO confirm.
    pub claim_timeout_ms: u64,

    /// Not referenced in this file; presumably the processing-stale threshold
    /// used by storage-side recovery — TODO confirm.
    pub processing_timeout_ms: u64,

    /// Not referenced in this file; presumably batch size for returning
    /// unprocessed claims to the queue — TODO confirm.
    pub unclaim_batch_size: usize,

    /// Interval at which the cancellation poller checks active batches for a
    /// `cancelling_at` marker, in milliseconds.
    pub cancellation_poll_interval_ms: u64,

    /// Batch fields exposed as metadata; defaults via
    /// [`default_batch_metadata_fields`]. Not used in this file — consumer
    /// presumably lives in the batch/storage layer.
    #[serde(default = "default_batch_metadata_fields")]
    pub batch_metadata_fields: Vec<String>,
}
153
/// Serde `default` hook for `DaemonConfig::batch_metadata_fields`:
/// the standard set of batch fields exposed as metadata.
fn default_batch_metadata_fields() -> Vec<String> {
    ["id", "endpoint", "created_at", "completion_window"]
        .into_iter()
        .map(String::from)
        .collect()
}
162
163impl Default for DaemonConfig {
164 fn default() -> Self {
165 Self {
166 claim_batch_size: 100,
167 default_model_concurrency: 10,
168 model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
169 model_escalations: default_model_escalations(),
170 claim_interval_ms: 1000,
171 max_retries: Some(1000),
172 stop_before_deadline_ms: Some(900_000),
173 backoff_ms: 1000,
174 backoff_factor: 2,
175 max_backoff_ms: 10000,
176 timeout_ms: 600000,
177 status_log_interval_ms: Some(2000), heartbeat_interval_ms: 10000, should_retry: Arc::new(default_should_retry),
180 claim_timeout_ms: 60000, processing_timeout_ms: 600000, unclaim_batch_size: 100, cancellation_poll_interval_ms: 5000, batch_metadata_fields: default_batch_metadata_fields(),
185 }
186 }
187}
188
/// Long-running worker that claims queued requests from storage and executes
/// them over HTTP, enforcing per-model concurrency limits.
///
/// Background work (heartbeats, status logging, batch-cancellation polling)
/// is spawned from [`Daemon::run`] and coordinated via `shutdown_token`.
pub struct Daemon<S, H>
where
    S: Storage + DaemonStorage,
    H: HttpClient,
{
    // Identity of this daemon instance; a fresh random UUID per `Daemon::new`.
    daemon_id: DaemonId,
    // Persistence backend used to claim, persist, and unclaim requests.
    storage: Arc<S>,
    // Client used to execute the actual HTTP calls.
    http_client: Arc<H>,
    // Runtime tuning knobs; a snapshot is persisted at registration time.
    config: DaemonConfig,
    // Per-model semaphore plus the limit it was last sized to (see
    // `get_semaphore` for the runtime-resize bookkeeping).
    semaphores: Arc<RwLock<HashMap<String, SemaphoreEntry>>>,
    // Gauge of currently-processing requests; reported in heartbeats.
    requests_in_flight: Arc<AtomicUsize>,
    // Lifetime counter of successfully completed requests.
    requests_processed: Arc<AtomicU64>,
    // Lifetime counter of permanently failed requests.
    requests_failed: Arc<AtomicU64>,
    // External signal that stops the main loop and all background tasks.
    shutdown_token: tokio_util::sync::CancellationToken,
    // Per-batch tokens; cancelled by the poller when storage marks a batch
    // as cancelling, aborting that batch's in-flight requests.
    cancellation_tokens: Arc<dashmap::DashMap<BatchId, tokio_util::sync::CancellationToken>>,
}
211
212impl<S, H> Daemon<S, H>
213where
214 S: Storage + DaemonStorage + 'static,
215 H: HttpClient + 'static,
216{
217 pub fn new(
219 storage: Arc<S>,
220 http_client: Arc<H>,
221 config: DaemonConfig,
222 shutdown_token: tokio_util::sync::CancellationToken,
223 ) -> Self {
224 Self {
225 daemon_id: DaemonId::from(uuid::Uuid::new_v4()),
226 storage,
227 http_client,
228 config,
229 semaphores: Arc::new(RwLock::new(HashMap::new())),
230 requests_in_flight: Arc::new(AtomicUsize::new(0)),
231 requests_processed: Arc::new(AtomicU64::new(0)),
232 requests_failed: Arc::new(AtomicU64::new(0)),
233 shutdown_token,
234 cancellation_tokens: Arc::new(dashmap::DashMap::new()),
235 }
236 }
237
    /// Returns the semaphore gating concurrency for `model`, creating it on
    /// first use and resizing it if the configured limit changed since the
    /// last call.
    async fn get_semaphore(&self, model: &str) -> Arc<Semaphore> {
        // Desired limit is read fresh from config each call (entries in the
        // DashMap may change at runtime), falling back to the global default.
        let current_limit = self
            .config
            .model_concurrency_limits
            .get(model)
            .map(|entry| *entry.value())
            .unwrap_or(self.config.default_model_concurrency);

        let mut semaphores = self.semaphores.write().await;

        // First use of this model: create a semaphore sized to the limit and
        // remember the limit it was sized to.
        let entry = semaphores
            .entry(model.to_string())
            .or_insert_with(|| (Arc::new(Semaphore::new(current_limit)), current_limit));

        let (semaphore, stored_limit) = entry;

        if *stored_limit != current_limit {
            if current_limit > *stored_limit {
                // Growing is simple: add the missing permits.
                let delta = current_limit - *stored_limit;
                semaphore.add_permits(delta);
                tracing::info!(
                    model = %model,
                    old_limit = *stored_limit,
                    new_limit = current_limit,
                    added_permits = delta,
                    "Increased model concurrency limit"
                );
                *stored_limit = current_limit;
            } else {
                // Shrinking: forget_permits can only remove permits that are
                // currently available, so permits held by in-flight requests
                // cannot be reclaimed here.
                let desired_delta = *stored_limit - current_limit;
                let actual_forgotten = semaphore.forget_permits(desired_delta);

                if actual_forgotten < desired_delta {
                    tracing::warn!(
                        model = %model,
                        old_limit = *stored_limit,
                        target_limit = current_limit,
                        desired_to_forget = desired_delta,
                        actually_forgot = actual_forgotten,
                        held_permits = desired_delta - actual_forgotten,
                        "Decreased model concurrency limit (some permits still held by in-flight requests)"
                    );
                } else {
                    tracing::info!(
                        model = %model,
                        old_limit = *stored_limit,
                        new_limit = current_limit,
                        forgot_permits = actual_forgotten,
                        "Decreased model concurrency limit"
                    );
                }

                // Record the limit actually in effect: target plus whatever
                // couldn't be forgotten. A later call will retry the shrink
                // once those held permits are released back.
                *stored_limit = current_limit + (desired_delta - actual_forgotten);
            }
        }

        semaphore.clone()
    }
306
307 async fn try_acquire_permit(&self, model: &str) -> Option<tokio::sync::OwnedSemaphorePermit> {
309 let semaphore = self.get_semaphore(model).await;
310 semaphore.clone().try_acquire_owned().ok()
311 }
312
    /// Main daemon entry point: registers this daemon in storage, spawns the
    /// background maintenance tasks, then loops claiming and dispatching
    /// requests until `shutdown_token` is cancelled.
    ///
    /// Background tasks spawned here:
    /// - heartbeat: periodically persists [`DaemonStats`] and marks the
    ///   daemon record as shut down when the token fires;
    /// - status logger (only if `status_log_interval_ms` is set): debug-logs
    ///   the in-flight count;
    /// - cancellation poller: fires per-batch cancellation tokens once
    ///   storage reports a batch as cancelling.
    ///
    /// # Errors
    /// Returns an error if registration or claiming from storage fails;
    /// individual request failures are handled inside spawned tasks and
    /// surfaced only through logs and counters.
    #[tracing::instrument(skip(self), fields(daemon_id = %self.daemon_id))]
    pub async fn run(self: Arc<Self>) -> Result<()> {
        tracing::info!("Daemon starting main processing loop");

        // Register this daemon, persisting a snapshot of its configuration.
        let daemon_record = DaemonRecord {
            data: DaemonData {
                id: self.daemon_id,
                hostname: types::get_hostname(),
                pid: types::get_pid(),
                version: types::get_version(),
                config_snapshot: serde_json::to_value(&self.config)
                    .expect("Failed to serialize daemon config"),
            },
            state: Initializing {
                started_at: chrono::Utc::now(),
            },
        };

        let running_record = daemon_record.start(self.storage.as_ref()).await?;
        tracing::info!("Daemon registered in database");

        let storage = self.storage.clone();
        let requests_in_flight = self.requests_in_flight.clone();
        let requests_processed = self.requests_processed.clone();
        let requests_failed = self.requests_failed.clone();
        let daemon_id = self.daemon_id;
        let heartbeat_interval_ms = self.config.heartbeat_interval_ms;
        let shutdown_signal = self.shutdown_token.clone();

        // Heartbeat task: periodically reports counters to storage; on
        // shutdown it marks the daemon record dead and exits. Heartbeat
        // failures are logged but do not stop the task.
        let heartbeat_handle = tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_millis(heartbeat_interval_ms));
            let mut daemon_record = running_record;

            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        let stats = DaemonStats {
                            requests_processed: requests_processed.load(Ordering::Relaxed),
                            requests_failed: requests_failed.load(Ordering::Relaxed),
                            requests_in_flight: requests_in_flight.load(Ordering::Relaxed),
                        };

                        // heartbeat() consumes the record, so clone and swap
                        // in the updated record only on success.
                        let current = daemon_record.clone();
                        match current.heartbeat(stats, storage.as_ref()).await {
                            Ok(updated) => {
                                daemon_record = updated;
                                tracing::trace!(
                                    daemon_id = %daemon_id,
                                    "Heartbeat sent"
                                );
                            }
                            Err(e) => {
                                tracing::error!(
                                    daemon_id = %daemon_id,
                                    error = %e,
                                    "Failed to send heartbeat"
                                );
                            }
                        }
                    }
                    _ = shutdown_signal.cancelled() => {
                        tracing::info!("Shutting down heartbeat task");
                        if let Err(e) = daemon_record.shutdown(storage.as_ref()).await {
                            tracing::error!(
                                daemon_id = %daemon_id,
                                error = %e,
                                "Failed to mark daemon as dead during shutdown"
                            );
                        }
                        break;
                    }
                }
            }
        });

        // Optional status logger. NOTE(review): this task has no shutdown
        // branch — it is simply abandoned when the runtime stops.
        if let Some(interval_ms) = self.config.status_log_interval_ms {
            let requests_in_flight = self.requests_in_flight.clone();
            let daemon_id = self.daemon_id;
            tokio::spawn(async move {
                let mut interval = tokio::time::interval(Duration::from_millis(interval_ms));
                loop {
                    interval.tick().await;
                    let count = requests_in_flight.load(Ordering::Relaxed);
                    tracing::debug!(
                        daemon_id = %daemon_id,
                        requests_in_flight = count,
                        "Daemon status"
                    );
                }
            });
        }

        // Cancellation poller: for every batch that currently has in-flight
        // work (i.e. has an entry in `cancellation_tokens`), ask storage
        // whether the batch was marked cancelling; if so, fire its token and
        // drop the entry.
        let cancellation_tokens = self.cancellation_tokens.clone();
        let storage = self.storage.clone();
        let shutdown_token = self.shutdown_token.clone();
        let cancellation_poll_interval_ms = self.config.cancellation_poll_interval_ms;
        tokio::spawn(async move {
            let mut interval =
                tokio::time::interval(Duration::from_millis(cancellation_poll_interval_ms));
            tracing::info!(
                interval_ms = cancellation_poll_interval_ms,
                "Batch cancellation polling started"
            );

            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        // Snapshot the keys first so we don't hold DashMap
                        // guards across the storage awaits below.
                        let active_batch_ids: Vec<BatchId> = cancellation_tokens
                            .iter()
                            .map(|entry| *entry.key())
                            .collect();

                        if active_batch_ids.is_empty() {
                            continue;
                        }

                        for batch_id in active_batch_ids {
                            if let Ok(batch) = storage.get_batch(batch_id).await
                                && batch.cancelling_at.is_some()
                                && let Some(entry) = cancellation_tokens.get(&batch_id) {
                                entry.value().cancel();
                                tracing::info!(batch_id = %batch_id, "Cancelled all requests in batch");
                                // Release the map guard before removing the key
                                // to avoid deadlocking the shard.
                                drop(entry);
                                cancellation_tokens.remove(&batch_id);
                            }
                        }
                    }
                    _ = shutdown_token.cancelled() => {
                        tracing::info!("Shutting down cancellation polling");
                        break;
                    }
                }
            }
        });

        let mut join_set: JoinSet<Result<()>> = JoinSet::new();

        // Main claim/dispatch loop. Breaks with Ok(()) on shutdown; claim
        // errors propagate out of `run` via `?`.
        let run_result = loop {
            if self.shutdown_token.is_cancelled() {
                tracing::info!("Shutdown signal received, stopping daemon");
                break Ok(());
            }

            // Drain any finished processing tasks; their errors are only
            // logged, never fatal to the daemon.
            while let Some(result) = join_set.try_join_next() {
                match result {
                    Ok(Ok(())) => {
                        tracing::trace!("Task completed successfully");
                    }
                    Ok(Err(e)) => {
                        tracing::error!(error = %e, "Task failed");
                    }
                    Err(join_error) => {
                        tracing::error!(error = %join_error, "Task panicked");
                    }
                }
            }

            tracing::trace!("Sleeping before claiming");
            tokio::select! {
                _ = tokio::time::sleep(Duration::from_millis(self.config.claim_interval_ms)) => {},
                _ = self.shutdown_token.cancelled() => {
                    tracing::info!("Shutdown signal received, stopping daemon");
                    break Ok(());
                }
            }
            let mut claimed = self
                .storage
                .claim_requests(self.config.claim_batch_size, self.daemon_id)
                .await?;

            counter!("fusillade_claims_total").increment(claimed.len() as u64);

            tracing::debug!(
                claimed_count = claimed.len(),
                "Claimed requests from storage"
            );

            // Escalation pass: requests whose batch is close to expiry are
            // rewritten to the configured escalation model, both on the
            // request record and inside the JSON body's "model" field.
            for request in &mut claimed {
                if let Some(config) = self.config.model_escalations.get(&request.data.model) {
                    let time_remaining = request.state.batch_expires_at - chrono::Utc::now();
                    if time_remaining.num_seconds() < config.escalation_threshold_seconds {
                        let original_model = request.data.model.clone();
                        request.data.model = config.escalation_model.clone();

                        // Best-effort body rewrite: a non-JSON or non-object
                        // body is left untouched.
                        if let Ok(mut json) =
                            serde_json::from_str::<serde_json::Value>(&request.data.body)
                            && let Some(obj) = json.as_object_mut()
                        {
                            obj.insert(
                                "model".to_string(),
                                serde_json::Value::String(config.escalation_model.clone()),
                            );
                            if let Ok(new_body) = serde_json::to_string(&json) {
                                request.data.body = new_body;
                            }
                        }

                        counter!("fusillade_requests_routed_to_escalation_total", "original_model" => original_model.clone(), "escalation_model" => config.escalation_model.clone()).increment(1);
                        tracing::info!(
                            request_id = %request.data.id,
                            original_model = %original_model,
                            escalation_model = %config.escalation_model,
                            time_remaining_seconds = time_remaining.num_seconds(),
                            threshold_seconds = config.escalation_threshold_seconds,
                            "Routing request to escalation model due to time pressure"
                        );
                    }
                }
            }

            // Group by (possibly escalated) model so permits are taken
            // against the right semaphore.
            let mut by_model: HashMap<String, Vec<_>> = HashMap::new();
            for request in claimed {
                let model = request.data.model.clone();
                by_model.entry(model).or_default().push(request);
            }

            tracing::debug!(
                models = by_model.len(),
                total_requests = by_model.values().map(|v| v.len()).sum::<usize>(),
                "Grouped requests by model"
            );

            for (model, requests) in by_model {
                tracing::debug!(model = %model, count = requests.len(), "Processing requests for model");

                for request in requests {
                    let request_id = request.data.id;
                    let batch_id = request.data.batch_id;

                    match self.try_acquire_permit(&model).await {
                        Some(permit) => {
                            tracing::debug!(
                                request_id = %request_id,
                                batch_id = %batch_id,
                                model = %model,
                                "Acquired permit, spawning processing task"
                            );

                            // Clone everything the task needs so it owns its
                            // state independently of `self`.
                            let model_clone = model.clone(); let storage = self.storage.clone();
                            let http_client = (*self.http_client).clone();
                            let timeout_ms = self.config.timeout_ms;
                            let retry_config = (&self.config).into();
                            let requests_in_flight = self.requests_in_flight.clone();
                            let requests_processed = self.requests_processed.clone();
                            let requests_failed = self.requests_failed.clone();
                            let should_retry = self.config.should_retry.clone();
                            let shutdown_token = self.shutdown_token.clone();
                            let cancellation_tokens = self.cancellation_tokens.clone();

                            // One shared token per batch; the cancellation
                            // poller fires it when the batch is cancelled.
                            let batch_cancellation_token =
                                cancellation_tokens.entry(batch_id).or_default().clone();

                            requests_in_flight.fetch_add(1, Ordering::Relaxed);
                            gauge!("fusillade_requests_in_flight", "model" => model_clone.clone())
                                .increment(1.0);

                            join_set.spawn(async move {
                                // Hold the concurrency permit for the whole
                                // task; dropping it frees the slot.
                                let _permit = permit;

                                let processing_start = std::time::Instant::now();

                                // Guard ensures the in-flight gauge/counter
                                // are decremented on every exit path,
                                // including `?` early returns.
                                let model_for_guard = model_clone.clone();
                                let _guard = scopeguard::guard((), move |_| {
                                    requests_in_flight.fetch_sub(1, Ordering::Relaxed);
                                    gauge!("fusillade_requests_in_flight", "model" => model_for_guard).decrement(1.0);
                                });

                                tracing::info!(request_id = %request_id, "Processing request");

                                let processing = request.process(
                                    http_client,
                                    timeout_ms,
                                    storage.as_ref()
                                ).await?;

                                let retry_attempt_at_completion = processing.state.retry_attempt;

                                // Lazy future resolving to the cancellation
                                // reason; only polled inside `complete`.
                                let cancellation = async {
                                    tokio::select! {
                                        _ = batch_cancellation_token.cancelled() => {
                                            crate::request::transitions::CancellationReason::User
                                        }
                                        _ = shutdown_token.cancelled() => {
                                            crate::request::transitions::CancellationReason::Shutdown
                                        }
                                    }
                                };

                                match processing.complete(storage.as_ref(), |response| {
                                    (should_retry)(response)
                                }, cancellation).await {
                                    Ok(RequestCompletionResult::Completed(_completed)) => {
                                        requests_processed.fetch_add(1, Ordering::Relaxed);
                                        counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "success").increment(1);
                                        histogram!("fusillade_request_duration_seconds", "model" => model_clone.clone(), "status" => "success")
                                            .record(processing_start.elapsed().as_secs_f64());
                                        histogram!("fusillade_retry_attempts_on_success", "model" => model_clone.clone())
                                            .record(retry_attempt_at_completion as f64);
                                        tracing::info!(request_id = %request_id, retry_attempts = retry_attempt_at_completion, "Request completed successfully");
                                    }
                                    Ok(RequestCompletionResult::Failed(failed)) => {
                                        let retry_attempt = failed.state.retry_attempt;

                                        if failed.state.reason.is_retriable() {
                                            tracing::warn!(
                                                request_id = %request_id,
                                                retry_attempt,
                                                error = %failed.state.reason.to_error_message(),
                                                "Request failed with retriable error, attempting retry"
                                            );

                                            // can_retry decides between re-queueing
                                            // (Ok) and permanent failure (Err).
                                            match failed.can_retry(retry_attempt, retry_config) {
                                                Ok(pending) => {
                                                    storage.persist(&pending).await?;
                                                    counter!(
                                                        "fusillade_requests_retried_total",
                                                        "model" => model_clone.clone(),
                                                        "attempt" => (retry_attempt + 1).to_string()
                                                    ).increment(1);
                                                    tracing::info!(
                                                        request_id = %request_id,
                                                        retry_attempt = retry_attempt + 1,
                                                        "Request queued for retry"
                                                    );
                                                }
                                                Err(failed) => {
                                                    storage.persist(&*failed).await?;
                                                    requests_failed.fetch_add(1, Ordering::Relaxed);
                                                    counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "failed").increment(1);
                                                    histogram!("fusillade_request_duration_seconds", "model" => model_clone.clone(), "status" => "failed")
                                                        .record(processing_start.elapsed().as_secs_f64());
                                                    tracing::warn!(
                                                        request_id = %request_id,
                                                        retry_attempt,
                                                        "Request failed permanently (no retries remaining)"
                                                    );
                                                }
                                            }
                                        } else {
                                            requests_failed.fetch_add(1, Ordering::Relaxed);
                                            counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "failed").increment(1);
                                            histogram!("fusillade_request_duration_seconds", "model" => model_clone.clone(), "status" => "failed")
                                                .record(processing_start.elapsed().as_secs_f64());
                                            tracing::warn!(
                                                request_id = %request_id,
                                                error = %failed.state.reason.to_error_message(),
                                                "Request failed with non-retriable error, not retrying"
                                            );
                                        }
                                    }
                                    Ok(RequestCompletionResult::Canceled(_canceled)) => {
                                        counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "cancelled").increment(1);
                                        tracing::debug!(request_id = %request_id, "Request canceled by user");
                                    }
                                    Err(FusilladeError::Shutdown) => {
                                        tracing::info!(request_id = %request_id, "Request aborted due to shutdown");
                                    }
                                    Err(e) => {
                                        tracing::error!(request_id = %request_id, error = %e, "Unexpected error processing request");
                                        return Err(e);
                                    }
                                }

                                Ok(())
                            });
                        }
                        None => {
                            // Model at its concurrency limit: return the claim
                            // so another daemon (or a later cycle) can take it.
                            tracing::debug!(
                                request_id = %request_id,
                                model = %model,
                                "No capacity available, unclaiming request"
                            );

                            let storage = self.storage.clone();
                            if let Err(e) = request.unclaim(storage.as_ref()).await {
                                tracing::error!(
                                    request_id = %request_id,
                                    error = %e,
                                    "Failed to unclaim request"
                                );
                            };
                        }
                    }
                }
            }
        };

        // Let the heartbeat task finish marking the daemon dead before
        // returning. NOTE(review): this join is skipped when a claim error
        // propagates via `?` above.
        tracing::info!("Waiting for heartbeat task to complete");
        if let Err(e) = heartbeat_handle.await {
            tracing::error!(error = %e, "Heartbeat task panicked");
        }

        run_result
    }
764}
765
766#[cfg(test)]
767mod tests {
768 use super::*;
769 use crate::TestDbPools;
770 use crate::http::{HttpResponse, MockHttpClient};
771 use crate::manager::{DaemonExecutor, postgres::PostgresRequestManager};
772 use std::time::Duration;
773
    /// End-to-end happy path: a single request flows from file → batch →
    /// claim → HTTP call → Completed, and the mock's recorded call matches
    /// the request template.
    #[sqlx::test]
    #[test_log::test]
    async fn test_daemon_claims_and_completes_request(pool: sqlx::PgPool) {
        // One canned 200 response for the only request the daemon will make.
        let http_client = Arc::new(MockHttpClient::new());
        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"success"}"#.to_string(),
            }),
        );

        // Fast claim interval so the test completes quickly.
        let config = DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10, default_model_concurrency: 10,
            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
            model_escalations: Arc::new(dashmap::DashMap::new()),
            max_retries: Some(3),
            stop_before_deadline_ms: None,
            backoff_ms: 100,
            backoff_factor: 2,
            max_backoff_ms: 1000,
            timeout_ms: 5000,
            status_log_interval_ms: None, heartbeat_interval_ms: 10000, should_retry: Arc::new(default_should_retry),
            claim_timeout_ms: 60000,
            processing_timeout_ms: 600000,
            unclaim_batch_size: 100,
            batch_metadata_fields: vec![],
            cancellation_poll_interval_ms: 100, };

        let manager = Arc::new(
            PostgresRequestManager::with_client(
                TestDbPools::new(pool.clone()).await.unwrap(),
                http_client.clone(),
            )
            .with_config(config),
        );

        // Seed one request template and wrap it in a batch.
        let file_id = manager
            .create_file(
                "test-file".to_string(),
                Some("Test file".to_string()),
                vec![crate::RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"test"}"#.to_string(),
                    model: "test-model".to_string(),
                    api_key: "test-key".to_string(),
                }],
            )
            .await
            .expect("Failed to create file");

        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
            })
            .await
            .expect("Failed to create batch");

        let requests = manager
            .get_batch_requests(batch.id)
            .await
            .expect("Failed to get batch requests");
        assert_eq!(requests.len(), 1);
        let request_id = requests[0].id();

        // Start the daemon loop in the background.
        let shutdown_token = tokio_util::sync::CancellationToken::new();
        manager
            .clone()
            .run(shutdown_token.clone())
            .expect("Failed to start daemon");

        // Poll until the request reaches a terminal state (or 5s elapses).
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(5);
        let mut completed = false;

        while start.elapsed() < timeout {
            let results = manager
                .get_requests(vec![request_id])
                .await
                .expect("Failed to get request");

            if let Some(Ok(any_request)) = results.first()
                && any_request.is_terminal()
            {
                if let crate::AnyRequest::Completed(req) = any_request {
                    assert_eq!(req.state.response_status, 200);
                    assert_eq!(req.state.response_body, r#"{"result":"success"}"#);
                    completed = true;
                    break;
                } else {
                    panic!(
                        "Request reached terminal state but was not completed: {:?}",
                        any_request
                    );
                }
            }

            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        shutdown_token.cancel();

        assert!(
            completed,
            "Request did not complete within timeout. Check daemon processing."
        );

        // The mock must have been called exactly once, with the template's
        // method/path/api-key.
        assert_eq!(http_client.call_count(), 1);
        let calls = http_client.get_calls();
        assert_eq!(calls[0].method, "POST");
        assert_eq!(calls[0].path, "/v1/test");
        assert_eq!(calls[0].api_key, "test-key");
    }
909
    /// Five requests for a model capped at concurrency 2: at most two run at
    /// once, a third starts only after one finishes, and all five eventually
    /// complete once their triggers fire.
    #[sqlx::test]
    async fn test_daemon_respects_per_model_concurrency_limits(pool: sqlx::PgPool) {
        let http_client = Arc::new(MockHttpClient::new());

        // Trigger-gated responses: each call blocks until its trigger fires,
        // letting the test control exactly when requests "finish".
        let trigger1 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"1"}"#.to_string(),
            }),
        );
        let trigger2 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"2"}"#.to_string(),
            }),
        );
        let trigger3 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"3"}"#.to_string(),
            }),
        );
        let trigger4 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"4"}"#.to_string(),
            }),
        );
        let trigger5 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"5"}"#.to_string(),
            }),
        );

        // Cap "gpt-4" at 2 concurrent requests (default would be 10).
        let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
        model_concurrency_limits.insert("gpt-4".to_string(), 2);

        let config = DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10,
            default_model_concurrency: 10,
            model_concurrency_limits,
            model_escalations: Arc::new(dashmap::DashMap::new()),
            max_retries: Some(3),
            stop_before_deadline_ms: None,
            backoff_ms: 100,
            backoff_factor: 2,
            max_backoff_ms: 1000,
            timeout_ms: 5000,
            status_log_interval_ms: None,
            heartbeat_interval_ms: 10000,
            should_retry: Arc::new(default_should_retry),
            claim_timeout_ms: 60000,
            processing_timeout_ms: 600000,
            unclaim_batch_size: 100,
            batch_metadata_fields: vec![],
            cancellation_poll_interval_ms: 100, };

        let manager = Arc::new(
            PostgresRequestManager::with_client(
                TestDbPools::new(pool.clone()).await.unwrap(),
                http_client.clone(),
            )
            .with_config(config),
        );

        // Five gpt-4 requests in a single file/batch.
        let file_id = manager
            .create_file(
                "test-file".to_string(),
                Some("Test concurrency limits".to_string()),
                vec![
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test1"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test2"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test3"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test4"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test5"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                ],
            )
            .await
            .expect("Failed to create file");

        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
            })
            .await
            .expect("Failed to create batch");

        let shutdown_token = tokio_util::sync::CancellationToken::new();
        manager
            .clone()
            .run(shutdown_token.clone())
            .expect("Failed to start daemon");

        // Wait for the daemon to ramp up to exactly 2 in-flight requests.
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(2);
        let mut reached_limit = false;

        while start.elapsed() < timeout {
            let in_flight = http_client.in_flight_count();
            if in_flight == 2 {
                reached_limit = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert!(
            reached_limit,
            "Expected exactly 2 requests in-flight, got {}",
            http_client.in_flight_count()
        );

        // Give the daemon a chance to (incorrectly) exceed the limit.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            http_client.in_flight_count(),
            2,
            "Concurrency limit violated: more than 2 requests in-flight"
        );

        // Finish the first request; its permit should free a slot.
        trigger1.send(()).unwrap();

        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(2);
        let mut third_started = false;

        while start.elapsed() < timeout {
            if http_client.call_count() >= 3 {
                third_started = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert!(
            third_started,
            "Third request should have started after first completed"
        );

        assert_eq!(
            http_client.in_flight_count(),
            2,
            "Should maintain concurrency limit of 2"
        );

        // Release everything else and wait for the whole batch to drain.
        trigger2.send(()).unwrap();
        trigger3.send(()).unwrap();
        trigger4.send(()).unwrap();
        trigger5.send(()).unwrap();

        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(5);
        let mut all_completed = false;

        while start.elapsed() < timeout {
            let status = manager
                .get_batch_status(batch.id)
                .await
                .expect("Failed to get batch status");

            if status.completed_requests == 5 {
                all_completed = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        shutdown_token.cancel();

        assert!(all_completed, "All 5 requests should have completed");

        assert_eq!(http_client.call_count(), 5);
    }
1148
    /// Retry path: the mock returns 500 then 503 (both retriable under
    /// [`default_should_retry`]) then 200, so the request must complete on
    /// the third attempt with exactly three HTTP calls.
    #[sqlx::test]
    async fn test_daemon_retries_failed_requests(pool: sqlx::PgPool) {
        let http_client = Arc::new(MockHttpClient::new());

        // Responses are consumed in order: 500, 503, then 200.
        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 500,
                body: r#"{"error":"internal error"}"#.to_string(),
            }),
        );

        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 503,
                body: r#"{"error":"service unavailable"}"#.to_string(),
            }),
        );

        http_client.add_response(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"success after retries"}"#.to_string(),
            }),
        );

        // Small backoff values keep the two retries well inside the 5s poll
        // window below.
        let config = DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10,
            default_model_concurrency: 10,
            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
            model_escalations: Arc::new(dashmap::DashMap::new()),
            max_retries: Some(5),
            stop_before_deadline_ms: None,
            backoff_ms: 10, backoff_factor: 2,
            max_backoff_ms: 100,
            timeout_ms: 5000,
            status_log_interval_ms: None,
            heartbeat_interval_ms: 10000,
            should_retry: Arc::new(default_should_retry),
            claim_timeout_ms: 60000,
            processing_timeout_ms: 600000,
            unclaim_batch_size: 100,
            batch_metadata_fields: vec![],
            cancellation_poll_interval_ms: 100, };

        let manager = Arc::new(
            PostgresRequestManager::with_client(
                TestDbPools::new(pool.clone()).await.unwrap(),
                http_client.clone(),
            )
            .with_config(config),
        );

        let file_id = manager
            .create_file(
                "test-file".to_string(),
                Some("Test retry logic".to_string()),
                vec![crate::RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/test".to_string(),
                    body: r#"{"prompt":"test"}"#.to_string(),
                    model: "test-model".to_string(),
                    api_key: "test-key".to_string(),
                }],
            )
            .await
            .expect("Failed to create file");

        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
            })
            .await
            .expect("Failed to create batch");

        let requests = manager
            .get_batch_requests(batch.id)
            .await
            .expect("Failed to get batch requests");
        assert_eq!(requests.len(), 1);
        let request_id = requests[0].id();

        let shutdown_token = tokio_util::sync::CancellationToken::new();
        manager
            .clone()
            .run(shutdown_token.clone())
            .expect("Failed to start daemon");

        // Poll until the request completes with the third (200) response.
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(5);
        let mut completed = false;

        while start.elapsed() < timeout {
            let results = manager
                .get_requests(vec![request_id])
                .await
                .expect("Failed to get request");

            if let Some(Ok(any_request)) = results.first()
                && let crate::AnyRequest::Completed(req) = any_request
            {
                assert_eq!(req.state.response_status, 200);
                assert_eq!(
                    req.state.response_body,
                    r#"{"result":"success after retries"}"#
                );
                completed = true;
                break;
            }

            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        shutdown_token.cancel();

        assert!(completed, "Request should have completed after retries");

        assert_eq!(
            http_client.call_count(),
            3,
            "Expected 3 HTTP calls (2 failed attempts + 1 success)"
        );
    }
1294
    #[sqlx::test]
    async fn test_daemon_dynamically_updates_concurrency_limits(pool: sqlx::PgPool) {
        // Verifies the daemon observes live edits to the shared
        // `model_concurrency_limits` DashMap without a restart:
        //   phase 1: limit 2  -> exactly 2 requests in flight
        //   phase 2: raise to 5 -> more requests admitted (>= 4 observed)
        //   phase 3: lower to 3 and release everything -> all 10 complete
        let http_client = Arc::new(MockHttpClient::new());
        // Queue 10 gated responses. Each returned trigger presumably holds the
        // mock response open until `send(())`, letting the test control how
        // many requests stay "in flight" at once — TODO confirm against
        // MockHttpClient::add_response_with_trigger.
        let mut triggers = vec![];
        for i in 1..=10 {
            let trigger = http_client.add_response_with_trigger(
                "POST /v1/test",
                Ok(HttpResponse {
                    status: 200,
                    body: format!(r#"{{"result":"{}"}}"#, i),
                }),
            );
            triggers.push(trigger);
        }
        // Start "gpt-4" at a per-model limit of 2. The same Arc'd map is moved
        // into the config below, so inserts made later in this test are the
        // exact map the daemon reads.
        let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
        model_concurrency_limits.insert("gpt-4".to_string(), 2);
        let config = DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10,
            default_model_concurrency: 10,
            model_concurrency_limits: model_concurrency_limits.clone(),
            model_escalations: Arc::new(dashmap::DashMap::new()),
            max_retries: Some(3),
            stop_before_deadline_ms: None,
            backoff_ms: 100,
            backoff_factor: 2,
            max_backoff_ms: 1000,
            timeout_ms: 5000,
            status_log_interval_ms: None,
            heartbeat_interval_ms: 10000,
            should_retry: Arc::new(default_should_retry),
            claim_timeout_ms: 60000,
            processing_timeout_ms: 600000,
            unclaim_batch_size: 100,
            batch_metadata_fields: vec![],
            cancellation_poll_interval_ms: 100, };
        let manager = Arc::new(
            PostgresRequestManager::with_client(
                TestDbPools::new(pool.clone()).await.unwrap(),
                http_client.clone(),
            )
            .with_config(config),
        );
        // 10 templates all targeting "gpt-4" so they contend for the same
        // per-model concurrency budget.
        let templates: Vec<_> = (1..=10)
            .map(|i| crate::RequestTemplateInput {
                custom_id: None,
                endpoint: "https://api.example.com".to_string(),
                method: "POST".to_string(),
                path: "/v1/test".to_string(),
                body: format!(r#"{{"prompt":"test{}"}}"#, i),
                model: "gpt-4".to_string(),
                api_key: "test-key".to_string(),
            })
            .collect();
        let file_id = manager
            .create_file(
                "test-file".to_string(),
                Some("Test dynamic limits".to_string()),
                templates,
            )
            .await
            .expect("Failed to create file");
        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
            })
            .await
            .expect("Failed to create batch");
        let shutdown_token = tokio_util::sync::CancellationToken::new();
        manager
            .clone()
            .run(shutdown_token.clone())
            .expect("Failed to start daemon");
        // Phase 1: poll until in-flight count plateaus at exactly the initial
        // limit of 2 (10ms polls, 2s budget).
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(2);
        let mut reached_initial_limit = false;
        while start.elapsed() < timeout {
            let in_flight = http_client.in_flight_count();
            if in_flight == 2 {
                reached_initial_limit = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert!(
            reached_initial_limit,
            "Expected exactly 2 requests in-flight with initial limit"
        );
        // Phase 2: raise the limit while the first two requests are still
        // blocked on their triggers.
        model_concurrency_limits.insert("gpt-4".to_string(), 5);
        // Grace period so the daemon can notice the new limit before we
        // measure anything.
        tokio::time::sleep(Duration::from_millis(100)).await;
        // Release one held request; its completion frees a slot and gives the
        // daemon an opportunity to admit more work under the raised limit.
        triggers.remove(0).send(()).unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        // `>= 4` (not `== 5`) deliberately tolerates scheduling jitter; the
        // point is only that more than the old limit of 2 are now running.
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(2);
        let mut reached_new_limit = false;
        while start.elapsed() < timeout {
            let in_flight = http_client.in_flight_count();
            if in_flight >= 4 {
                reached_new_limit = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert!(
            reached_new_limit,
            "Expected more requests in-flight after limit increase, got {}",
            http_client.in_flight_count()
        );
        // Phase 3: lower the limit below the current in-flight count. No
        // assertion on the interim count — the check is simply that nothing
        // deadlocks and all work still drains.
        model_concurrency_limits.insert("gpt-4".to_string(), 3);
        // Release every remaining gated response so all requests can finish.
        for trigger in triggers {
            trigger.send(()).unwrap();
        }
        // Poll batch status until all 10 requests report completed (5s budget).
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(5);
        let mut all_completed = false;
        while start.elapsed() < timeout {
            let status = manager
                .get_batch_status(batch.id)
                .await
                .expect("Failed to get batch status");
            if status.completed_requests == 10 {
                all_completed = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        shutdown_token.cancel();
        assert!(all_completed, "All 10 requests should have completed");
        // Exactly one HTTP call per request — no retries expected on 200s.
        assert_eq!(http_client.call_count(), 10);
    }
1469
1470 #[sqlx::test]
1471 async fn test_deadline_aware_retry_stops_before_deadline(pool: sqlx::PgPool) {
1472 let http_client = Arc::new(MockHttpClient::new());
1474
1475 for _ in 0..20 {
1477 http_client.add_response(
1478 "POST /v1/test",
1479 Ok(HttpResponse {
1480 status: 500,
1481 body: r#"{"error":"server error"}"#.to_string(),
1482 }),
1483 );
1484 }
1485
1486 let config = DaemonConfig {
1488 claim_batch_size: 10,
1489 claim_interval_ms: 10,
1490 default_model_concurrency: 10,
1491 model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1492 model_escalations: Arc::new(dashmap::DashMap::new()),
1493 max_retries: Some(10_000),
1494 stop_before_deadline_ms: Some(500), backoff_ms: 50,
1496 backoff_factor: 2,
1497 max_backoff_ms: 200,
1498 timeout_ms: 5000,
1499 status_log_interval_ms: None,
1500 heartbeat_interval_ms: 10000,
1501 should_retry: Arc::new(default_should_retry),
1502 claim_timeout_ms: 60000,
1503 processing_timeout_ms: 600000,
1504 unclaim_batch_size: 100,
1505 batch_metadata_fields: vec![],
1506 cancellation_poll_interval_ms: 100,
1507 };
1508
1509 let manager = Arc::new(
1510 PostgresRequestManager::with_client(
1511 TestDbPools::new(pool.clone()).await.unwrap(),
1512 http_client.clone(),
1513 )
1514 .with_config(config),
1515 );
1516
1517 let file_id = manager
1519 .create_file(
1520 "test-file".to_string(),
1521 Some("Test deadline cutoff".to_string()),
1522 vec![crate::RequestTemplateInput {
1523 custom_id: None,
1524 endpoint: "https://api.example.com".to_string(),
1525 method: "POST".to_string(),
1526 path: "/v1/test".to_string(),
1527 body: r#"{"prompt":"test"}"#.to_string(),
1528 model: "test-model".to_string(),
1529 api_key: "test-key".to_string(),
1530 }],
1531 )
1532 .await
1533 .expect("Failed to create file");
1534
1535 let batch = manager
1536 .create_batch(crate::batch::BatchInput {
1537 file_id,
1538 endpoint: "/v1/chat/completions".to_string(),
1539 completion_window: "2s".to_string(), metadata: None,
1541 created_by: None,
1542 })
1543 .await
1544 .expect("Failed to create batch");
1545
1546 let requests = manager
1547 .get_batch_requests(batch.id)
1548 .await
1549 .expect("Failed to get batch requests");
1550 let request_id = requests[0].id();
1551
1552 let shutdown_token = tokio_util::sync::CancellationToken::new();
1554 manager
1555 .clone()
1556 .run(shutdown_token.clone())
1557 .expect("Failed to start daemon");
1558
1559 tokio::time::sleep(Duration::from_secs(3)).await;
1561
1562 let results = manager
1564 .get_requests(vec![request_id])
1565 .await
1566 .expect("Failed to get request");
1567
1568 shutdown_token.cancel();
1569
1570 if let Some(Ok(crate::AnyRequest::Failed(failed))) = results.first() {
1571 let retry_count = failed.state.retry_attempt;
1590 let call_count = http_client.call_count();
1591
1592 assert!(
1596 (4..=9).contains(&retry_count),
1597 "Expected 4-9 retry attempts based on deadline and backoff calculation, got {}",
1598 retry_count
1599 );
1600
1601 assert_eq!(
1603 call_count,
1604 (retry_count + 1) as usize,
1605 "Expected call count to match retry attempts + 1 initial attempt, got {} calls for {} retry attempts",
1606 call_count,
1607 retry_count
1608 );
1609
1610 assert!(
1612 !failed.state.reason.to_error_message().is_empty(),
1613 "Expected failed request to have failure reason"
1614 );
1615 } else {
1616 panic!(
1617 "Expected request to be in Failed state, got {:?}",
1618 results.first()
1619 );
1620 }
1621 }
1622
1623 #[sqlx::test]
1624 async fn test_retry_stops_at_deadline_when_no_limits_set(pool: sqlx::PgPool) {
1625 let http_client = Arc::new(MockHttpClient::new());
1628
1629 for _ in 0..20 {
1631 http_client.add_response(
1632 "POST /v1/test",
1633 Ok(HttpResponse {
1634 status: 500,
1635 body: r#"{"error":"server error"}"#.to_string(),
1636 }),
1637 );
1638 }
1639
1640 let config = DaemonConfig {
1642 claim_batch_size: 10,
1643 claim_interval_ms: 10,
1644 default_model_concurrency: 10,
1645 model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1646 model_escalations: Arc::new(dashmap::DashMap::new()),
1647 max_retries: None, stop_before_deadline_ms: None, backoff_ms: 50,
1650 backoff_factor: 2,
1651 max_backoff_ms: 200,
1652 timeout_ms: 5000,
1653 status_log_interval_ms: None,
1654 heartbeat_interval_ms: 10000,
1655 should_retry: Arc::new(default_should_retry),
1656 claim_timeout_ms: 60000,
1657 processing_timeout_ms: 600000,
1658 unclaim_batch_size: 100,
1659 batch_metadata_fields: vec![],
1660 cancellation_poll_interval_ms: 100,
1661 };
1662
1663 let manager = Arc::new(
1664 PostgresRequestManager::with_client(
1665 TestDbPools::new(pool.clone()).await.unwrap(),
1666 http_client.clone(),
1667 )
1668 .with_config(config),
1669 );
1670
1671 let file_id = manager
1673 .create_file(
1674 "test-file".to_string(),
1675 Some("Test no limits retry".to_string()),
1676 vec![crate::RequestTemplateInput {
1677 custom_id: None,
1678 endpoint: "https://api.example.com".to_string(),
1679 method: "POST".to_string(),
1680 path: "/v1/test".to_string(),
1681 body: r#"{"prompt":"test"}"#.to_string(),
1682 model: "test-model".to_string(),
1683 api_key: "test-key".to_string(),
1684 }],
1685 )
1686 .await
1687 .expect("Failed to create file");
1688
1689 let batch = manager
1690 .create_batch(crate::batch::BatchInput {
1691 file_id,
1692 endpoint: "/v1/chat/completions".to_string(),
1693 completion_window: "2s".to_string(),
1694 metadata: None,
1695 created_by: None,
1696 })
1697 .await
1698 .expect("Failed to create batch");
1699
1700 let requests = manager
1701 .get_batch_requests(batch.id)
1702 .await
1703 .expect("Failed to get batch requests");
1704 let request_id = requests[0].id();
1705
1706 let shutdown_token = tokio_util::sync::CancellationToken::new();
1708 manager
1709 .clone()
1710 .run(shutdown_token.clone())
1711 .expect("Failed to start daemon");
1712
1713 tokio::time::sleep(Duration::from_secs(3)).await;
1715
1716 let results = manager
1718 .get_requests(vec![request_id])
1719 .await
1720 .expect("Failed to get request");
1721
1722 shutdown_token.cancel();
1723
1724 if let Some(Ok(crate::AnyRequest::Failed(failed))) = results.first() {
1725 let retry_count = failed.state.retry_attempt;
1748 let call_count = http_client.call_count();
1749
1750 assert!(
1755 (6..12).contains(&retry_count),
1756 "Expected 6-12 retry attempts (should retry until deadline with no buffer), got {}",
1757 retry_count
1758 );
1759
1760 assert_eq!(
1762 call_count,
1763 (retry_count + 1) as usize,
1764 "Expected call count to match retry attempts + 1 initial attempt, got {} calls for {} retry attempts",
1765 call_count,
1766 retry_count
1767 );
1768
1769 assert!(
1771 !failed.state.reason.to_error_message().is_empty(),
1772 "Expected failed request to have failure reason"
1773 );
1774 } else {
1775 panic!(
1776 "Expected request to be in Failed state, got {:?}",
1777 results.first()
1778 );
1779 }
1780 }
1781
1782 #[sqlx::test]
1783 #[test_log::test]
1784 async fn test_batch_metadata_headers_passed_through(pool: sqlx::PgPool) {
1785 let http_client = crate::http::MockHttpClient::new();
1786 http_client.add_response(
1787 "POST /v1/chat/completions",
1788 Ok(crate::http::HttpResponse {
1789 status: 200,
1790 body: r#"{"id":"chatcmpl-123","choices":[{"message":{"content":"test"}}]}"#
1791 .to_string(),
1792 }),
1793 );
1794
1795 let config = DaemonConfig {
1796 claim_batch_size: 10,
1797 claim_interval_ms: 10,
1798 default_model_concurrency: 10,
1799 model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1800 model_escalations: Arc::new(dashmap::DashMap::new()),
1801 max_retries: Some(3),
1802 stop_before_deadline_ms: None,
1803 backoff_ms: 100,
1804 backoff_factor: 2,
1805 max_backoff_ms: 1000,
1806 timeout_ms: 5000,
1807 status_log_interval_ms: None,
1808 heartbeat_interval_ms: 10000,
1809 should_retry: Arc::new(default_should_retry),
1810 claim_timeout_ms: 60000,
1811 processing_timeout_ms: 600000,
1812 unclaim_batch_size: 100,
1813 batch_metadata_fields: vec![
1814 "id".to_string(),
1815 "endpoint".to_string(),
1816 "created_at".to_string(),
1817 "completion_window".to_string(),
1818 ],
1819 cancellation_poll_interval_ms: 100,
1820 };
1821
1822 let manager = Arc::new(
1823 PostgresRequestManager::with_client(
1824 TestDbPools::new(pool.clone()).await.unwrap(),
1825 Arc::new(http_client.clone()),
1826 )
1827 .with_config(config),
1828 );
1829
1830 let file_id = manager
1832 .create_file(
1833 "test-file".to_string(),
1834 Some("Test batch metadata".to_string()),
1835 vec![crate::RequestTemplateInput {
1836 custom_id: None,
1837 endpoint: "https://api.example.com".to_string(),
1838 method: "POST".to_string(),
1839 path: "/v1/chat/completions".to_string(),
1840 body: r#"{"prompt":"test"}"#.to_string(),
1841 model: "test-model".to_string(),
1842 api_key: "test-key".to_string(),
1843 }],
1844 )
1845 .await
1846 .expect("Failed to create file");
1847
1848 let batch = manager
1849 .create_batch(crate::batch::BatchInput {
1850 file_id,
1851 endpoint: "/v1/chat/completions".to_string(),
1852 completion_window: "24h".to_string(),
1853 metadata: None,
1854 created_by: Some("test-user".to_string()),
1855 })
1856 .await
1857 .expect("Failed to create batch");
1858
1859 let requests = manager
1860 .get_batch_requests(batch.id)
1861 .await
1862 .expect("Failed to get batch requests");
1863 let request_id = requests[0].id();
1864
1865 let shutdown_token = tokio_util::sync::CancellationToken::new();
1867 manager
1868 .clone()
1869 .run(shutdown_token.clone())
1870 .expect("Failed to start daemon");
1871
1872 tokio::time::sleep(Duration::from_millis(500)).await;
1874
1875 shutdown_token.cancel();
1876
1877 tokio::time::sleep(Duration::from_millis(200)).await;
1879
1880 let results = manager
1882 .get_requests(vec![request_id])
1883 .await
1884 .expect("Failed to get request");
1885
1886 assert_eq!(results.len(), 1);
1887 assert!(
1888 matches!(results[0], Ok(crate::AnyRequest::Completed(_))),
1889 "Expected request to be completed"
1890 );
1891
1892 let calls = http_client.get_calls();
1894 assert_eq!(calls.len(), 1, "Expected exactly one HTTP call");
1895
1896 let call = &calls[0];
1897 assert_eq!(
1898 call.batch_metadata.len(),
1899 4,
1900 "Expected 4 batch metadata fields"
1901 );
1902
1903 assert!(
1905 call.batch_metadata.contains_key("id"),
1906 "Expected batch id in metadata"
1907 );
1908 assert!(
1909 call.batch_metadata.contains_key("endpoint"),
1910 "Expected batch endpoint in metadata"
1911 );
1912 assert!(
1913 call.batch_metadata.contains_key("created_at"),
1914 "Expected batch created_at in metadata"
1915 );
1916 assert!(
1917 call.batch_metadata.contains_key("completion_window"),
1918 "Expected batch completion_window in metadata"
1919 );
1920
1921 assert_eq!(
1923 call.batch_metadata.get("endpoint"),
1924 Some(&"/v1/chat/completions".to_string()),
1925 "Batch endpoint should match"
1926 );
1927 assert_eq!(
1928 call.batch_metadata.get("completion_window"),
1929 Some(&"24h".to_string()),
1930 "Completion window should match"
1931 );
1932 }
1933
    #[sqlx::test]
    #[test_log::test]
    async fn test_batch_metadata_extracts_fields_from_json_metadata(pool: sqlx::PgPool) {
        // Verifies that a configured metadata field ("request_source") that is
        // not a built-in batch column is extracted from the batch's JSON
        // `metadata` document and forwarded to the HTTP client, while JSON
        // keys not listed in the config ("created_by") are excluded.
        let http_client = crate::http::MockHttpClient::new();
        http_client.add_response(
            "POST /v1/chat/completions",
            Ok(crate::http::HttpResponse {
                status: 200,
                body: r#"{"id":"chatcmpl-123","choices":[{"message":{"content":"test"}}]}"#
                    .to_string(),
            }),
        );
        // Three built-in batch fields plus one ("request_source") that only
        // exists inside the batch's JSON metadata blob.
        let config = DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10,
            default_model_concurrency: 10,
            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
            model_escalations: Arc::new(dashmap::DashMap::new()),
            max_retries: Some(3),
            stop_before_deadline_ms: None,
            backoff_ms: 100,
            backoff_factor: 2,
            max_backoff_ms: 1000,
            timeout_ms: 5000,
            status_log_interval_ms: None,
            heartbeat_interval_ms: 10000,
            should_retry: Arc::new(default_should_retry),
            claim_timeout_ms: 60000,
            processing_timeout_ms: 600000,
            unclaim_batch_size: 100,
            batch_metadata_fields: vec![
                "id".to_string(),
                "endpoint".to_string(),
                "completion_window".to_string(),
                "request_source".to_string(), ],
            cancellation_poll_interval_ms: 100,
        };
        let manager = Arc::new(
            PostgresRequestManager::with_client(
                TestDbPools::new(pool.clone()).await.unwrap(),
                Arc::new(http_client.clone()),
            )
            .with_config(config),
        );
        let file_id = manager
            .create_file(
                "test-file".to_string(),
                Some("Test metadata JSON extraction".to_string()),
                vec![crate::RequestTemplateInput {
                    custom_id: None,
                    endpoint: "https://api.example.com".to_string(),
                    method: "POST".to_string(),
                    path: "/v1/chat/completions".to_string(),
                    body: r#"{"prompt":"test"}"#.to_string(),
                    model: "test-model".to_string(),
                    api_key: "test-key".to_string(),
                }],
            )
            .await
            .expect("Failed to create file");
        // "request_source" lives only in the metadata JSON; "created_by" is
        // present in the JSON but deliberately NOT in batch_metadata_fields.
        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: Some(serde_json::json!({
                    "request_source": "api",
                    "created_by": "user-123"
                })),
                created_by: Some("test-user".to_string()),
            })
            .await
            .expect("Failed to create batch");
        let requests = manager
            .get_batch_requests(batch.id)
            .await
            .expect("Failed to get batch requests");
        let request_id = requests[0].id();
        let shutdown_token = tokio_util::sync::CancellationToken::new();
        manager
            .clone()
            .run(shutdown_token.clone())
            .expect("Failed to start daemon");
        // Fixed sleeps: give the daemon time to process the request, cancel,
        // then allow a short grace period for shutdown before reading state.
        tokio::time::sleep(Duration::from_millis(500)).await;
        shutdown_token.cancel();
        tokio::time::sleep(Duration::from_millis(200)).await;
        let results = manager
            .get_requests(vec![request_id])
            .await
            .expect("Failed to get request");
        assert_eq!(results.len(), 1);
        assert!(
            matches!(results[0], Ok(crate::AnyRequest::Completed(_))),
            "Expected request to be completed"
        );
        let calls = http_client.get_calls();
        assert_eq!(calls.len(), 1, "Expected exactly one HTTP call");
        let call = &calls[0];
        // Exactly the four configured fields — unlisted JSON keys must not
        // leak through.
        assert_eq!(
            call.batch_metadata.len(),
            4,
            "Expected 4 batch metadata fields (id, endpoint, completion_window, request_source)"
        );
        assert!(
            call.batch_metadata.contains_key("id"),
            "Expected batch id in metadata"
        );
        assert_eq!(
            call.batch_metadata.get("endpoint"),
            Some(&"/v1/chat/completions".to_string()),
            "Batch endpoint should match"
        );
        // The JSON-only field arrives with its string value intact.
        assert_eq!(
            call.batch_metadata.get("request_source"),
            Some(&"api".to_string()),
            "request_source should be extracted from metadata JSON"
        );
    }
2078}