// fusillade/daemon/mod.rs
1//! Daemon for processing batched requests with per-model concurrency control.
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
5use std::time::Duration;
6
7use metrics::{counter, gauge, histogram};
8use tokio::sync::{RwLock, Semaphore};
9use tokio::task::JoinSet;
10
11use crate::FusilladeError;
12use crate::batch::BatchId;
13use crate::error::Result;
14use crate::http::{HttpClient, HttpResponse};
15use crate::manager::{DaemonStorage, Storage};
16use crate::request::{DaemonId, RequestCompletionResult};
17
18pub mod transitions;
19pub mod types;
20
21pub use types::{
22    AnyDaemonRecord, DaemonData, DaemonRecord, DaemonState, DaemonStats, DaemonStatus, Dead,
23    Initializing, Running,
24};
25
/// Predicate function to determine if a response should be retried.
///
/// Takes an HTTP response and returns true if the request should be retried.
/// Wrapped in `Arc` so `DaemonConfig` (which holds one) can derive `Clone`
/// and the predicate can be shared with spawned processing tasks.
pub type ShouldRetryFn = Arc<dyn Fn(&HttpResponse) -> bool + Send + Sync>;

/// Semaphore entry tracking both the semaphore and its configured limit.
/// Storing the limit alongside the semaphore lets `get_semaphore` detect
/// when the configured limit has changed and adjust permits accordingly.
type SemaphoreEntry = (Arc<Semaphore>, usize);
33
34/// Default retry predicate: retry on server errors (5xx), rate limits (429), and timeouts (408).
35pub fn default_should_retry(response: &HttpResponse) -> bool {
36    response.status >= 500 || response.status == 429 || response.status == 408
37}
38
39/// Default function for creating the should_retry Arc
40fn default_should_retry_fn() -> ShouldRetryFn {
41    Arc::new(default_should_retry)
42}
43
44/// Default model escalations (empty map)
45fn default_model_escalations() -> Arc<dashmap::DashMap<String, ModelEscalationConfig>> {
46    Arc::new(dashmap::DashMap::new())
47}
48
/// Default escalation threshold: 15 minutes, in seconds.
///
/// This should be greater than the processing timeout (10 minutes) to allow
/// a processing request to fall back to pending before escalation kicks in.
fn default_escalation_threshold_seconds() -> i64 {
    15 * 60
}
55
/// Model-based escalation configuration for routing requests to a different model
/// at claim time when approaching SLA deadline.
///
/// When a request is claimed with less than `escalation_threshold_seconds` remaining
/// before batch expiry, it will be routed to the `escalation_model` instead of the
/// original model. The batch API key automatically has access to escalation models
/// in the onwards routing cache (no separate API key needed).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelEscalationConfig {
    /// The model to escalate to (e.g., "o1-preview" for requests using "gpt-4")
    pub escalation_model: String,

    /// Time threshold in seconds - escalate when time remaining before batch expiry
    /// is less than this value. Default: 900 (15 minutes)
    ///
    /// Compared against `batch_expires_at - now` at claim time in the daemon loop.
    #[serde(default = "default_escalation_threshold_seconds")]
    pub escalation_threshold_seconds: i64,
}
73
/// Configuration for the daemon.
///
/// Serializable so a snapshot can be recorded with the daemon's database
/// record (see `config_snapshot` in `run`). Fields that cannot be serialized
/// (`should_retry`, `model_escalations`) are skipped and rebuilt from their
/// defaults on deserialization.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct DaemonConfig {
    /// Maximum number of requests to claim in each iteration
    pub claim_batch_size: usize,

    /// Default concurrency limit per model
    pub default_model_concurrency: usize,

    /// Per-model concurrency overrides (shared, can be updated dynamically)
    ///
    /// `get_semaphore` re-reads this map on every call, so limit changes take
    /// effect without restarting the daemon.
    pub model_concurrency_limits: Arc<dashmap::DashMap<String, usize>>,

    /// Per-model escalation configurations for SLA-based model switching
    /// Maps model name -> escalation config (e.g., "gpt-4" -> "o1-preview")
    /// When a request is escalated, it's routed to the escalation_model by the control layer
    #[serde(skip, default = "default_model_escalations")]
    pub model_escalations: Arc<dashmap::DashMap<String, ModelEscalationConfig>>,

    /// How long to sleep between claim iterations
    pub claim_interval_ms: u64,

    /// Maximum number of retry attempts before giving up.
    pub max_retries: Option<u32>,

    /// Stop retrying this many milliseconds before the batch expires.
    /// Positive values stop before the deadline (safety buffer).
    /// Negative values allow retrying after the deadline.
    /// If None, retries are not deadline-aware.
    pub stop_before_deadline_ms: Option<i64>,

    /// Base backoff duration in milliseconds (will be exponentially increased)
    pub backoff_ms: u64,

    /// Factor by which the backoff_ms is increased with each retry
    pub backoff_factor: u64,

    /// Maximum backoff time in milliseconds
    pub max_backoff_ms: u64,

    /// Timeout for each individual request attempt in milliseconds
    pub timeout_ms: u64,

    /// Interval for logging daemon status (requests in flight) in milliseconds
    /// Set to None to disable periodic status logging
    pub status_log_interval_ms: Option<u64>,

    /// Interval for sending heartbeats to update daemon status in database (milliseconds)
    pub heartbeat_interval_ms: u64,

    /// Predicate function to determine if a response should be retried.
    /// Defaults to retrying 5xx, 429, and 408 status codes.
    #[serde(skip, default = "default_should_retry_fn")]
    pub should_retry: ShouldRetryFn,

    /// Maximum time a request can stay in "claimed" state before being unclaimed
    /// and returned to pending (milliseconds). This handles daemon crashes.
    pub claim_timeout_ms: u64,

    /// Maximum time a request can stay in "processing" state before being unclaimed
    /// and returned to pending (milliseconds). This handles daemon crashes during execution.
    pub processing_timeout_ms: u64,

    /// Maximum number of stale requests to unclaim in a single poll cycle.
    /// Limits database load when many requests become stale simultaneously (e.g., daemon crash).
    pub unclaim_batch_size: usize,

    /// Interval for polling database to check for cancelled batches (milliseconds)
    /// Determines how quickly in-flight requests are aborted when their batch is cancelled
    pub cancellation_poll_interval_ms: u64,

    /// Batch table column names to include as request headers.
    /// These values are sent as `x-fusillade-batch-{column}` headers with each request.
    /// Example: ["id", "created_by", "endpoint"] produces headers like:
    ///   - x-fusillade-batch-id
    ///   - x-fusillade-batch-created-by
    ///   - x-fusillade-batch-endpoint
    #[serde(default = "default_batch_metadata_fields")]
    pub batch_metadata_fields: Vec<String>,
}
153
/// Default batch columns forwarded as `x-fusillade-batch-*` request headers.
fn default_batch_metadata_fields() -> Vec<String> {
    ["id", "endpoint", "created_at", "completion_window"]
        .into_iter()
        .map(String::from)
        .collect()
}
162
163impl Default for DaemonConfig {
164    fn default() -> Self {
165        Self {
166            claim_batch_size: 100,
167            default_model_concurrency: 10,
168            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
169            model_escalations: default_model_escalations(),
170            claim_interval_ms: 1000,
171            max_retries: Some(1000),
172            stop_before_deadline_ms: Some(900_000),
173            backoff_ms: 1000,
174            backoff_factor: 2,
175            max_backoff_ms: 10000,
176            timeout_ms: 600000,
177            status_log_interval_ms: Some(2000), // Log every 2 seconds by default
178            heartbeat_interval_ms: 10000,       // Heartbeat every 10 seconds by default
179            should_retry: Arc::new(default_should_retry),
180            claim_timeout_ms: 60000,             // 1 minute
181            processing_timeout_ms: 600000,       // 10 minutes
182            unclaim_batch_size: 100,             // Unclaim up to 100 stale requests per poll
183            cancellation_poll_interval_ms: 5000, // Poll every 5 seconds by default
184            batch_metadata_fields: default_batch_metadata_fields(),
185        }
186    }
187}
188
/// Daemon that processes batched requests.
///
/// The daemon continuously claims pending requests from storage, enforces
/// per-model concurrency limits, and dispatches requests for execution.
pub struct Daemon<S, H>
where
    S: Storage + DaemonStorage,
    H: HttpClient,
{
    /// Unique identifier for this daemon instance (random UUID, see `new`).
    daemon_id: DaemonId,
    /// Backend used to claim requests, persist state, and record heartbeats.
    storage: Arc<S>,
    /// HTTP client used to execute the actual upstream requests.
    http_client: Arc<H>,
    /// Runtime tuning knobs; see [`DaemonConfig`].
    config: DaemonConfig,
    /// Per-model semaphores paired with the limit each was last adjusted to.
    /// Created lazily by `get_semaphore`.
    semaphores: Arc<RwLock<HashMap<String, SemaphoreEntry>>>,
    /// Count of requests currently being processed; reported in heartbeat stats.
    requests_in_flight: Arc<AtomicUsize>,
    /// Cumulative successfully-completed request count; reported in heartbeat stats.
    requests_processed: Arc<AtomicU64>,
    /// Cumulative failed request count; reported in heartbeat stats.
    requests_failed: Arc<AtomicU64>,
    /// Token signalling daemon-wide shutdown to the main loop and helper tasks.
    shutdown_token: tokio_util::sync::CancellationToken,
    /// Map of batch_id -> cancellation token for batch-level cancellation
    /// All requests in a batch share the same cancellation token
    cancellation_tokens: Arc<dashmap::DashMap<BatchId, tokio_util::sync::CancellationToken>>,
}
211
212impl<S, H> Daemon<S, H>
213where
214    S: Storage + DaemonStorage + 'static,
215    H: HttpClient + 'static,
216{
217    /// Create a new daemon.
218    pub fn new(
219        storage: Arc<S>,
220        http_client: Arc<H>,
221        config: DaemonConfig,
222        shutdown_token: tokio_util::sync::CancellationToken,
223    ) -> Self {
224        Self {
225            daemon_id: DaemonId::from(uuid::Uuid::new_v4()),
226            storage,
227            http_client,
228            config,
229            semaphores: Arc::new(RwLock::new(HashMap::new())),
230            requests_in_flight: Arc::new(AtomicUsize::new(0)),
231            requests_processed: Arc::new(AtomicU64::new(0)),
232            requests_failed: Arc::new(AtomicU64::new(0)),
233            shutdown_token,
234            cancellation_tokens: Arc::new(dashmap::DashMap::new()),
235        }
236    }
237
    /// Get or create a semaphore for a model.
    ///
    /// Automatically adjusts the semaphore's permit count if the configured limit has changed.
    /// For limit increases, adds permits. For decreases, forgets permits (as many as possible).
    /// Note: When decreasing, we can only forget permits that aren't currently held, so the
    /// effective limit may temporarily remain higher until requests complete.
    ///
    /// Takes the `semaphores` write lock for the whole adjustment so concurrent
    /// callers cannot both add/forget permits for the same model.
    async fn get_semaphore(&self, model: &str) -> Arc<Semaphore> {
        // Resolve the currently configured limit: per-model override if present,
        // otherwise the global default.
        let current_limit = self
            .config
            .model_concurrency_limits
            .get(model)
            .map(|entry| *entry.value())
            .unwrap_or(self.config.default_model_concurrency);

        let mut semaphores = self.semaphores.write().await;

        // First use of this model: create a semaphore sized to the current limit
        // and remember that limit so future calls can detect changes.
        let entry = semaphores
            .entry(model.to_string())
            .or_insert_with(|| (Arc::new(Semaphore::new(current_limit)), current_limit));

        let (semaphore, stored_limit) = entry;

        // Check if the limit has changed
        if *stored_limit != current_limit {
            if current_limit > *stored_limit {
                // Limit increased - add permits
                let delta = current_limit - *stored_limit;
                semaphore.add_permits(delta);
                tracing::info!(
                    model = %model,
                    old_limit = *stored_limit,
                    new_limit = current_limit,
                    added_permits = delta,
                    "Increased model concurrency limit"
                );
                *stored_limit = current_limit;
            } else {
                // Limit decreased - forget permits (as many as we can)
                // `forget_permits` returns how many were actually removed; it
                // cannot reclaim permits currently held by in-flight requests.
                let desired_delta = *stored_limit - current_limit;
                let actual_forgotten = semaphore.forget_permits(desired_delta);

                if actual_forgotten < desired_delta {
                    tracing::warn!(
                        model = %model,
                        old_limit = *stored_limit,
                        target_limit = current_limit,
                        desired_to_forget = desired_delta,
                        actually_forgot = actual_forgotten,
                        held_permits = desired_delta - actual_forgotten,
                        "Decreased model concurrency limit (some permits still held by in-flight requests)"
                    );
                } else {
                    tracing::info!(
                        model = %model,
                        old_limit = *stored_limit,
                        new_limit = current_limit,
                        forgot_permits = actual_forgotten,
                        "Decreased model concurrency limit"
                    );
                }

                // Update to the new effective limit (accounting for unforgettable permits).
                // If some permits were still held, the stored limit stays above the
                // target, so a later call will retry forgetting the remainder.
                *stored_limit = current_limit + (desired_delta - actual_forgotten);
            }
        }

        semaphore.clone()
    }
306
307    /// Try to acquire a permit for a model (non-blocking).
308    async fn try_acquire_permit(&self, model: &str) -> Option<tokio::sync::OwnedSemaphorePermit> {
309        let semaphore = self.get_semaphore(model).await;
310        semaphore.clone().try_acquire_owned().ok()
311    }
312
313    /// Run the daemon loop.
314    ///
315    /// This continuously claims and processes requests until an error occurs
316    /// or the task is cancelled.
317    ///
318    /// The daemon periodically polls for cancelled batches and aborts in-flight requests.
319    #[tracing::instrument(skip(self), fields(daemon_id = %self.daemon_id))]
320    pub async fn run(self: Arc<Self>) -> Result<()> {
321        tracing::info!("Daemon starting main processing loop");
322
323        // Register daemon in database
324        let daemon_record = DaemonRecord {
325            data: DaemonData {
326                id: self.daemon_id,
327                hostname: types::get_hostname(),
328                pid: types::get_pid(),
329                version: types::get_version(),
330                config_snapshot: serde_json::to_value(&self.config)
331                    .expect("Failed to serialize daemon config"),
332            },
333            state: Initializing {
334                started_at: chrono::Utc::now(),
335            },
336        };
337
338        let running_record = daemon_record.start(self.storage.as_ref()).await?;
339        tracing::info!("Daemon registered in database");
340
341        // Spawn periodic heartbeat task
342        let storage = self.storage.clone();
343        let requests_in_flight = self.requests_in_flight.clone();
344        let requests_processed = self.requests_processed.clone();
345        let requests_failed = self.requests_failed.clone();
346        let daemon_id = self.daemon_id;
347        let heartbeat_interval_ms = self.config.heartbeat_interval_ms;
348        let shutdown_signal = self.shutdown_token.clone();
349
350        let heartbeat_handle = tokio::spawn(async move {
351            let mut interval = tokio::time::interval(Duration::from_millis(heartbeat_interval_ms));
352            let mut daemon_record = running_record;
353
354            loop {
355                tokio::select! {
356                    _ = interval.tick() => {
357                        let stats = DaemonStats {
358                            requests_processed: requests_processed.load(Ordering::Relaxed),
359                            requests_failed: requests_failed.load(Ordering::Relaxed),
360                            requests_in_flight: requests_in_flight.load(Ordering::Relaxed),
361                        };
362
363                        // Clone the record so we preserve it if heartbeat fails
364                        let current = daemon_record.clone();
365                        match current.heartbeat(stats, storage.as_ref()).await {
366                            Ok(updated) => {
367                                daemon_record = updated;
368                                tracing::trace!(
369                                    daemon_id = %daemon_id,
370                                    "Heartbeat sent"
371                                );
372                            }
373                            Err(e) => {
374                                tracing::error!(
375                                    daemon_id = %daemon_id,
376                                    error = %e,
377                                    "Failed to send heartbeat"
378                                );
379                                // daemon_record stays unchanged on error
380                            }
381                        }
382                    }
383                    _ = shutdown_signal.cancelled() => {
384                        // Mark daemon as dead on shutdown
385                        tracing::info!("Shutting down heartbeat task");
386                        if let Err(e) = daemon_record.shutdown(storage.as_ref()).await {
387                            tracing::error!(
388                                daemon_id = %daemon_id,
389                                error = %e,
390                                "Failed to mark daemon as dead during shutdown"
391                            );
392                        }
393                        break;
394                    }
395                }
396            }
397        });
398
399        // Spawn periodic status logging task if configured
400        if let Some(interval_ms) = self.config.status_log_interval_ms {
401            let requests_in_flight = self.requests_in_flight.clone();
402            let daemon_id = self.daemon_id;
403            tokio::spawn(async move {
404                let mut interval = tokio::time::interval(Duration::from_millis(interval_ms));
405                loop {
406                    interval.tick().await;
407                    let count = requests_in_flight.load(Ordering::Relaxed);
408                    tracing::debug!(
409                        daemon_id = %daemon_id,
410                        requests_in_flight = count,
411                        "Daemon status"
412                    );
413                }
414            });
415        }
416
417        // Spawn periodic task to poll for cancelled batches and abort in-flight requests
418        let cancellation_tokens = self.cancellation_tokens.clone();
419        let storage = self.storage.clone();
420        let shutdown_token = self.shutdown_token.clone();
421        let cancellation_poll_interval_ms = self.config.cancellation_poll_interval_ms;
422        tokio::spawn(async move {
423            let mut interval =
424                tokio::time::interval(Duration::from_millis(cancellation_poll_interval_ms));
425            tracing::info!(
426                interval_ms = cancellation_poll_interval_ms,
427                "Batch cancellation polling started"
428            );
429
430            loop {
431                tokio::select! {
432                    _ = interval.tick() => {
433                        // Get all active batch IDs we're currently processing
434                        let active_batch_ids: Vec<BatchId> = cancellation_tokens
435                            .iter()
436                            .map(|entry| *entry.key())
437                            .collect();
438
439                        if active_batch_ids.is_empty() {
440                            continue;
441                        }
442
443                        // Query database to check which of these batches have been cancelled
444                        // Note: DaemonStorage doesn't have a method for this, so we'll check via the batch
445                        // For now, we'll check each batch individually
446                        for batch_id in active_batch_ids {
447                            // Try to get the batch - if it has cancelling_at set, cancel the token
448                            if let Ok(batch) = storage.get_batch(batch_id).await
449                                && batch.cancelling_at.is_some()
450                                    && let Some(entry) = cancellation_tokens.get(&batch_id) {
451                                        entry.value().cancel();
452                                        tracing::info!(batch_id = %batch_id, "Cancelled all requests in batch");
453                                        // Remove from map so we don't keep checking it
454                                        drop(entry);
455                                        cancellation_tokens.remove(&batch_id);
456                                    }
457                        }
458                    }
459                    _ = shutdown_token.cancelled() => {
460                        tracing::info!("Shutting down cancellation polling");
461                        break;
462                    }
463                }
464            }
465        });
466
467        let mut join_set: JoinSet<Result<()>> = JoinSet::new();
468
469        let run_result = loop {
470            // Check for shutdown signal
471            if self.shutdown_token.is_cancelled() {
472                tracing::info!("Shutdown signal received, stopping daemon");
473                break Ok(());
474            }
475
476            // Poll for completed tasks (non-blocking)
477            while let Some(result) = join_set.try_join_next() {
478                match result {
479                    Ok(Ok(())) => {
480                        tracing::trace!("Task completed successfully");
481                    }
482                    Ok(Err(e)) => {
483                        tracing::error!(error = %e, "Task failed");
484                    }
485                    Err(join_error) => {
486                        tracing::error!(error = %join_error, "Task panicked");
487                    }
488                }
489            }
490
491            tracing::trace!("Sleeping before claiming");
492            tokio::select! {
493                _ = tokio::time::sleep(Duration::from_millis(self.config.claim_interval_ms)) => {},
494                _ = self.shutdown_token.cancelled() => {
495                    tracing::info!("Shutdown signal received, stopping daemon");
496                    break Ok(());
497                }
498            }
499            // Claim a batch of pending requests
500            let mut claimed = self
501                .storage
502                .claim_requests(self.config.claim_batch_size, self.daemon_id)
503                .await?;
504
505            // Record claim metrics
506            counter!("fusillade_claims_total").increment(claimed.len() as u64);
507
508            tracing::debug!(
509                claimed_count = claimed.len(),
510                "Claimed requests from storage"
511            );
512
513            // Route requests to escalated models if time is running low
514            // This replaces the old SLA racing system with a simpler approach:
515            // at claim time, we check if there's enough time remaining and route
516            // to the escalated model if below threshold
517            for request in &mut claimed {
518                if let Some(config) = self.config.model_escalations.get(&request.data.model) {
519                    let time_remaining = request.state.batch_expires_at - chrono::Utc::now();
520                    if time_remaining.num_seconds() < config.escalation_threshold_seconds {
521                        let original_model = request.data.model.clone();
522                        request.data.model = config.escalation_model.clone();
523
524                        // Update the model field in the request body JSON
525                        if let Ok(mut json) =
526                            serde_json::from_str::<serde_json::Value>(&request.data.body)
527                            && let Some(obj) = json.as_object_mut()
528                        {
529                            obj.insert(
530                                "model".to_string(),
531                                serde_json::Value::String(config.escalation_model.clone()),
532                            );
533                            if let Ok(new_body) = serde_json::to_string(&json) {
534                                request.data.body = new_body;
535                            }
536                        }
537
538                        // No API key swap needed - batch API keys automatically have access
539                        // to escalation models in the onwards routing cache
540                        counter!("fusillade_requests_routed_to_escalation_total", "original_model" => original_model.clone(), "escalation_model" => config.escalation_model.clone()).increment(1);
541                        tracing::info!(
542                            request_id = %request.data.id,
543                            original_model = %original_model,
544                            escalation_model = %config.escalation_model,
545                            time_remaining_seconds = time_remaining.num_seconds(),
546                            threshold_seconds = config.escalation_threshold_seconds,
547                            "Routing request to escalation model due to time pressure"
548                        );
549                    }
550                }
551            }
552
553            // Group requests by model for better concurrency control visibility
554            let mut by_model: HashMap<String, Vec<_>> = HashMap::new();
555            for request in claimed {
556                let model = request.data.model.clone();
557                by_model.entry(model).or_default().push(request);
558            }
559
560            tracing::debug!(
561                models = by_model.len(),
562                total_requests = by_model.values().map(|v| v.len()).sum::<usize>(),
563                "Grouped requests by model"
564            );
565
566            // Dispatch requests
567            for (model, requests) in by_model {
568                tracing::debug!(model = %model, count = requests.len(), "Processing requests for model");
569
570                for request in requests {
571                    let request_id = request.data.id;
572                    let batch_id = request.data.batch_id;
573
574                    // Try to acquire a semaphore permit for this model
575                    match self.try_acquire_permit(&model).await {
576                        Some(permit) => {
577                            tracing::debug!(
578                                request_id = %request_id,
579                                batch_id = %batch_id,
580                                model = %model,
581                                "Acquired permit, spawning processing task"
582                            );
583
584                            // We have capacity - spawn a task
585                            let model_clone = model.clone(); // Clone model for the spawned task
586                            let storage = self.storage.clone();
587                            let http_client = (*self.http_client).clone();
588                            let timeout_ms = self.config.timeout_ms;
589                            let retry_config = (&self.config).into();
590                            let requests_in_flight = self.requests_in_flight.clone();
591                            let requests_processed = self.requests_processed.clone();
592                            let requests_failed = self.requests_failed.clone();
593                            let should_retry = self.config.should_retry.clone();
594                            let shutdown_token = self.shutdown_token.clone();
595                            let cancellation_tokens = self.cancellation_tokens.clone();
596
597                            // Get or create a cancellation token for this batch
598                            // All requests in a batch share the same token
599                            let batch_cancellation_token =
600                                cancellation_tokens.entry(batch_id).or_default().clone();
601
602                            // Increment in-flight counter and gauge
603                            requests_in_flight.fetch_add(1, Ordering::Relaxed);
604                            gauge!("fusillade_requests_in_flight", "model" => model_clone.clone())
605                                .increment(1.0);
606
607                            join_set.spawn(async move {
608                                // Permit is held for the duration of this task
609                                let _permit = permit;
610
611                                // Track processing start time for duration metrics
612                                let processing_start = std::time::Instant::now();
613
614                                // Ensure we decrement the counter when this task completes
615                                let model_for_guard = model_clone.clone();
616                                let _guard = scopeguard::guard((), move |_| {
617                                    requests_in_flight.fetch_sub(1, Ordering::Relaxed);
618                                    gauge!("fusillade_requests_in_flight", "model" => model_for_guard).decrement(1.0);
619                                });
620
621                                tracing::info!(request_id = %request_id, "Processing request");
622
623                                // Launch request processing (this goes on a background thread)
624                                let processing = request.process(
625                                    http_client,
626                                    timeout_ms,
627                                    storage.as_ref()
628                                ).await?;
629
630                                // Capture retry attempt count before completion (not preserved in Completed state)
631                                let retry_attempt_at_completion = processing.state.retry_attempt;
632
633                                let cancellation = async {
634                                    tokio::select! {
635                                        _ = batch_cancellation_token.cancelled() => {
636                                            crate::request::transitions::CancellationReason::User
637                                        }
638                                        _ = shutdown_token.cancelled() => {
639                                            crate::request::transitions::CancellationReason::Shutdown
640                                        }
641                                    }
642                                };
643
644                                // Wait for completion
645                                match processing.complete(storage.as_ref(), |response| {
646                                    (should_retry)(response)
647                                }, cancellation).await {
648                                    Ok(RequestCompletionResult::Completed(_completed)) => {
649                                        requests_processed.fetch_add(1, Ordering::Relaxed);
650                                        counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "success").increment(1);
651                                        histogram!("fusillade_request_duration_seconds", "model" => model_clone.clone(), "status" => "success")
652                                            .record(processing_start.elapsed().as_secs_f64());
653                                        // Record how many retries it took to succeed (0 = first attempt succeeded)
654                                        histogram!("fusillade_retry_attempts_on_success", "model" => model_clone.clone())
655                                            .record(retry_attempt_at_completion as f64);
656                                        tracing::info!(request_id = %request_id, retry_attempts = retry_attempt_at_completion, "Request completed successfully");
657                                    }
658                                    Ok(RequestCompletionResult::Failed(failed)) => {
659                                        let retry_attempt = failed.state.retry_attempt;
660
661                                        // Check if this is a retriable error using the FailureReason
662                                        if failed.state.reason.is_retriable() {
663                                            tracing::warn!(
664                                                request_id = %request_id,
665                                                retry_attempt,
666                                                error = %failed.state.reason.to_error_message(),
667                                                "Request failed with retriable error, attempting retry"
668                                            );
669
670                                            // Attempt to retry
671                                            match failed.can_retry(retry_attempt, retry_config) {
672                                                Ok(pending) => {
673                                                    // Can retry - persist as Pending
674                                                    storage.persist(&pending).await?;
675                                                    counter!(
676                                                        "fusillade_requests_retried_total",
677                                                        "model" => model_clone.clone(),
678                                                        "attempt" => (retry_attempt + 1).to_string()
679                                                    ).increment(1);
680                                                    tracing::info!(
681                                                        request_id = %request_id,
682                                                        retry_attempt = retry_attempt + 1,
683                                                        "Request queued for retry"
684                                                    );
685                                                }
686                                                Err(failed) => {
687                                                    // No retries left - persist as Failed (terminal)
688                                                    storage.persist(&*failed).await?;
689                                                    requests_failed.fetch_add(1, Ordering::Relaxed);
690                                                    counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "failed").increment(1);
691                                                    histogram!("fusillade_request_duration_seconds", "model" => model_clone.clone(), "status" => "failed")
692                                                        .record(processing_start.elapsed().as_secs_f64());
693                                                    tracing::warn!(
694                                                        request_id = %request_id,
695                                                        retry_attempt,
696                                                        "Request failed permanently (no retries remaining)"
697                                                    );
698                                                }
699                                            }
700                                        } else {
701                                            requests_failed.fetch_add(1, Ordering::Relaxed);
702                                            counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "failed").increment(1);
703                                            histogram!("fusillade_request_duration_seconds", "model" => model_clone.clone(), "status" => "failed")
704                                                .record(processing_start.elapsed().as_secs_f64());
705                                            tracing::warn!(
706                                                request_id = %request_id,
707                                                error = %failed.state.reason.to_error_message(),
708                                                "Request failed with non-retriable error, not retrying"
709                                            );
710                                        }
711                                    }
712                                    Ok(RequestCompletionResult::Canceled(_canceled)) => {
713                                        counter!("fusillade_requests_completed_total", "model" => model_clone.clone(), "status" => "cancelled").increment(1);
714                                        tracing::debug!(request_id = %request_id, "Request canceled by user");
715                                    }
716                                    Err(FusilladeError::Shutdown) => {
717                                        tracing::info!(request_id = %request_id, "Request aborted due to shutdown");
718                                        // Don't count as failed - request will be reclaimed
719                                    }
720                                    Err(e) => {
721                                        // Unexpected error
722                                        tracing::error!(request_id = %request_id, error = %e, "Unexpected error processing request");
723                                        return Err(e);
724                                    }
725                                }
726
727                                // Note: We don't remove the batch cancellation token here since
728                                // multiple requests in the same batch share it. Tokens are cleaned
729                                // up when the daemon shuts down or batch completes.
730
731                                Ok(())
732                            });
733                        }
734                        None => {
735                            tracing::debug!(
736                                request_id = %request_id,
737                                model = %model,
738                                "No capacity available, unclaiming request"
739                            );
740
741                            // No capacity for this model - unclaim the request
742                            let storage = self.storage.clone();
743                            if let Err(e) = request.unclaim(storage.as_ref()).await {
744                                tracing::error!(
745                                    request_id = %request_id,
746                                    error = %e,
747                                    "Failed to unclaim request"
748                                );
749                            };
750                        }
751                    }
752                }
753            }
754        };
755
756        // Wait for heartbeat task to complete (it will mark daemon as dead)
757        tracing::info!("Waiting for heartbeat task to complete");
758        if let Err(e) = heartbeat_handle.await {
759            tracing::error!(error = %e, "Heartbeat task panicked");
760        }
761
762        run_result
763    }
764}
765
766#[cfg(test)]
767mod tests {
768    use super::*;
769    use crate::TestDbPools;
770    use crate::http::{HttpResponse, MockHttpClient};
771    use crate::manager::{DaemonExecutor, postgres::PostgresRequestManager};
772    use std::time::Duration;
773
774    #[sqlx::test]
775    #[test_log::test]
776    async fn test_daemon_claims_and_completes_request(pool: sqlx::PgPool) {
777        // Setup: Create HTTP client with mock response
778        let http_client = Arc::new(MockHttpClient::new());
779        http_client.add_response(
780            "POST /v1/test",
781            Ok(HttpResponse {
782                status: 200,
783                body: r#"{"result":"success"}"#.to_string(),
784            }),
785        );
786
787        // Setup: Create manager with fast claim interval (no sleeping)
788        let config = DaemonConfig {
789            claim_batch_size: 10,
790            claim_interval_ms: 10, // Very fast for testing
791            default_model_concurrency: 10,
792            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
793            model_escalations: Arc::new(dashmap::DashMap::new()),
794            max_retries: Some(3),
795            stop_before_deadline_ms: None,
796            backoff_ms: 100,
797            backoff_factor: 2,
798            max_backoff_ms: 1000,
799            timeout_ms: 5000,
800            status_log_interval_ms: None, // Disable status logging in tests
801            heartbeat_interval_ms: 10000, // 10 seconds
802            should_retry: Arc::new(default_should_retry),
803            claim_timeout_ms: 60000,
804            processing_timeout_ms: 600000,
805            unclaim_batch_size: 100,
806            batch_metadata_fields: vec![],
807            cancellation_poll_interval_ms: 100, // Fast polling for tests
808        };
809
810        let manager = Arc::new(
811            PostgresRequestManager::with_client(
812                TestDbPools::new(pool.clone()).await.unwrap(),
813                http_client.clone(),
814            )
815            .with_config(config),
816        );
817
818        // Setup: Create a file and batch to associate with our request
819        let file_id = manager
820            .create_file(
821                "test-file".to_string(),
822                Some("Test file".to_string()),
823                vec![crate::RequestTemplateInput {
824                    custom_id: None,
825                    endpoint: "https://api.example.com".to_string(),
826                    method: "POST".to_string(),
827                    path: "/v1/test".to_string(),
828                    body: r#"{"prompt":"test"}"#.to_string(),
829                    model: "test-model".to_string(),
830                    api_key: "test-key".to_string(),
831                }],
832            )
833            .await
834            .expect("Failed to create file");
835
836        let batch = manager
837            .create_batch(crate::batch::BatchInput {
838                file_id,
839                endpoint: "/v1/chat/completions".to_string(),
840                completion_window: "24h".to_string(),
841                metadata: None,
842                created_by: None,
843            })
844            .await
845            .expect("Failed to create batch");
846
847        // Get the created request from the batch
848        let requests = manager
849            .get_batch_requests(batch.id)
850            .await
851            .expect("Failed to get batch requests");
852        assert_eq!(requests.len(), 1);
853        let request_id = requests[0].id();
854
855        // Start the daemon
856        let shutdown_token = tokio_util::sync::CancellationToken::new();
857        manager
858            .clone()
859            .run(shutdown_token.clone())
860            .expect("Failed to start daemon");
861
862        // Poll for completion (with timeout)
863        let start = tokio::time::Instant::now();
864        let timeout = Duration::from_secs(5);
865        let mut completed = false;
866
867        while start.elapsed() < timeout {
868            let results = manager
869                .get_requests(vec![request_id])
870                .await
871                .expect("Failed to get request");
872
873            if let Some(Ok(any_request)) = results.first()
874                && any_request.is_terminal()
875            {
876                if let crate::AnyRequest::Completed(req) = any_request {
877                    // Verify the request was completed successfully
878                    assert_eq!(req.state.response_status, 200);
879                    assert_eq!(req.state.response_body, r#"{"result":"success"}"#);
880                    completed = true;
881                    break;
882                } else {
883                    panic!(
884                        "Request reached terminal state but was not completed: {:?}",
885                        any_request
886                    );
887                }
888            }
889
890            tokio::time::sleep(Duration::from_millis(50)).await;
891        }
892
893        // Stop the daemon
894        shutdown_token.cancel();
895
896        // Assert that the request completed
897        assert!(
898            completed,
899            "Request did not complete within timeout. Check daemon processing."
900        );
901
902        // Verify HTTP client was called exactly once
903        assert_eq!(http_client.call_count(), 1);
904        let calls = http_client.get_calls();
905        assert_eq!(calls[0].method, "POST");
906        assert_eq!(calls[0].path, "/v1/test");
907        assert_eq!(calls[0].api_key, "test-key");
908    }
909
    #[sqlx::test]
    async fn test_daemon_respects_per_model_concurrency_limits(pool: sqlx::PgPool) {
        // Verifies that at most `model_concurrency_limits["gpt-4"]` (= 2)
        // requests are in flight at once, and that finishing one request
        // frees a slot for the next queued request.
        //
        // Setup: Create HTTP client with triggered responses.
        // A "triggered" response is held in-flight by the mock until the
        // returned trigger fires, which lets the test observe and control
        // concurrency deterministically.
        let http_client = Arc::new(MockHttpClient::new());

        // Add 5 triggered responses for our 5 requests
        let trigger1 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"1"}"#.to_string(),
            }),
        );
        let trigger2 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"2"}"#.to_string(),
            }),
        );
        let trigger3 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"3"}"#.to_string(),
            }),
        );
        let trigger4 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"4"}"#.to_string(),
            }),
        );
        let trigger5 = http_client.add_response_with_trigger(
            "POST /v1/test",
            Ok(HttpResponse {
                status: 200,
                body: r#"{"result":"5"}"#.to_string(),
            }),
        );

        // Setup: Create manager with concurrency limit of 2 for "gpt-4"
        let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
        model_concurrency_limits.insert("gpt-4".to_string(), 2);

        let config = DaemonConfig {
            claim_batch_size: 10,
            claim_interval_ms: 10,
            default_model_concurrency: 10,
            model_concurrency_limits,
            model_escalations: Arc::new(dashmap::DashMap::new()),
            max_retries: Some(3),
            stop_before_deadline_ms: None,
            backoff_ms: 100,
            backoff_factor: 2,
            max_backoff_ms: 1000,
            timeout_ms: 5000,
            status_log_interval_ms: None,
            heartbeat_interval_ms: 10000,
            should_retry: Arc::new(default_should_retry),
            claim_timeout_ms: 60000,
            processing_timeout_ms: 600000,
            unclaim_batch_size: 100,
            batch_metadata_fields: vec![],
            cancellation_poll_interval_ms: 100, // Fast polling for tests
        };

        let manager = Arc::new(
            PostgresRequestManager::with_client(
                TestDbPools::new(pool.clone()).await.unwrap(),
                http_client.clone(),
            )
            .with_config(config),
        );

        // Setup: Create a file with 5 templates, all using "gpt-4"
        // so they all contend for the same 2-permit semaphore.
        let file_id = manager
            .create_file(
                "test-file".to_string(),
                Some("Test concurrency limits".to_string()),
                vec![
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test1"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test2"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test3"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test4"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                    crate::RequestTemplateInput {
                        custom_id: None,
                        endpoint: "https://api.example.com".to_string(),
                        method: "POST".to_string(),
                        path: "/v1/test".to_string(),
                        body: r#"{"prompt":"test5"}"#.to_string(),
                        model: "gpt-4".to_string(),
                        api_key: "test-key".to_string(),
                    },
                ],
            )
            .await
            .expect("Failed to create file");

        let batch = manager
            .create_batch(crate::batch::BatchInput {
                file_id,
                endpoint: "/v1/chat/completions".to_string(),
                completion_window: "24h".to_string(),
                metadata: None,
                created_by: None,
            })
            .await
            .expect("Failed to create batch");

        // Start the daemon
        let shutdown_token = tokio_util::sync::CancellationToken::new();
        manager
            .clone()
            .run(shutdown_token.clone())
            .expect("Failed to start daemon");

        // Wait for exactly 2 requests to be in-flight (respecting concurrency limit)
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(2);
        let mut reached_limit = false;

        while start.elapsed() < timeout {
            let in_flight = http_client.in_flight_count();
            if in_flight == 2 {
                reached_limit = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert!(
            reached_limit,
            "Expected exactly 2 requests in-flight, got {}",
            http_client.in_flight_count()
        );

        // Verify exactly 2 are in-flight (not more).
        // The extra sleep gives the daemon time to (incorrectly) start a
        // third request if the limit were not enforced.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            http_client.in_flight_count(),
            2,
            "Concurrency limit violated: more than 2 requests in-flight"
        );

        // Trigger completion of first request
        // NOTE(review): trigger appears to be a channel sender that releases
        // the held mock response — confirm against MockHttpClient.
        trigger1.send(()).unwrap();

        // Wait for the third request to start
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(2);
        let mut third_started = false;

        while start.elapsed() < timeout {
            if http_client.call_count() >= 3 {
                third_started = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert!(
            third_started,
            "Third request should have started after first completed"
        );

        // Verify still only 2 in-flight: one permit was released and
        // immediately consumed by the third request.
        assert_eq!(
            http_client.in_flight_count(),
            2,
            "Should maintain concurrency limit of 2"
        );

        // Complete remaining requests to clean up
        trigger2.send(()).unwrap();
        trigger3.send(()).unwrap();
        trigger4.send(()).unwrap();
        trigger5.send(()).unwrap();

        // Wait for all requests to complete
        let start = tokio::time::Instant::now();
        let timeout = Duration::from_secs(5);
        let mut all_completed = false;

        while start.elapsed() < timeout {
            let status = manager
                .get_batch_status(batch.id)
                .await
                .expect("Failed to get batch status");

            if status.completed_requests == 5 {
                all_completed = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        // Stop the daemon
        shutdown_token.cancel();

        assert!(all_completed, "All 5 requests should have completed");

        // Verify all 5 HTTP calls were made
        assert_eq!(http_client.call_count(), 5);
    }
1148
1149    #[sqlx::test]
1150    async fn test_daemon_retries_failed_requests(pool: sqlx::PgPool) {
1151        // Setup: Create HTTP client with failing responses, then success
1152        let http_client = Arc::new(MockHttpClient::new());
1153
1154        // First attempt: fails with 500
1155        http_client.add_response(
1156            "POST /v1/test",
1157            Ok(HttpResponse {
1158                status: 500,
1159                body: r#"{"error":"internal error"}"#.to_string(),
1160            }),
1161        );
1162
1163        // Second attempt: fails with 503
1164        http_client.add_response(
1165            "POST /v1/test",
1166            Ok(HttpResponse {
1167                status: 503,
1168                body: r#"{"error":"service unavailable"}"#.to_string(),
1169            }),
1170        );
1171
1172        // Third attempt: succeeds
1173        http_client.add_response(
1174            "POST /v1/test",
1175            Ok(HttpResponse {
1176                status: 200,
1177                body: r#"{"result":"success after retries"}"#.to_string(),
1178            }),
1179        );
1180
1181        // Setup: Create manager with fast backoff for testing
1182        let config = DaemonConfig {
1183            claim_batch_size: 10,
1184            claim_interval_ms: 10,
1185            default_model_concurrency: 10,
1186            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1187            model_escalations: Arc::new(dashmap::DashMap::new()),
1188            max_retries: Some(5),
1189            stop_before_deadline_ms: None,
1190            backoff_ms: 10, // Very fast backoff for testing
1191            backoff_factor: 2,
1192            max_backoff_ms: 100,
1193            timeout_ms: 5000,
1194            status_log_interval_ms: None,
1195            heartbeat_interval_ms: 10000,
1196            should_retry: Arc::new(default_should_retry),
1197            claim_timeout_ms: 60000,
1198            processing_timeout_ms: 600000,
1199            unclaim_batch_size: 100,
1200            batch_metadata_fields: vec![],
1201            cancellation_poll_interval_ms: 100, // Fast polling for tests
1202        };
1203
1204        let manager = Arc::new(
1205            PostgresRequestManager::with_client(
1206                TestDbPools::new(pool.clone()).await.unwrap(),
1207                http_client.clone(),
1208            )
1209            .with_config(config),
1210        );
1211
1212        // Setup: Create a file and batch
1213        let file_id = manager
1214            .create_file(
1215                "test-file".to_string(),
1216                Some("Test retry logic".to_string()),
1217                vec![crate::RequestTemplateInput {
1218                    custom_id: None,
1219                    endpoint: "https://api.example.com".to_string(),
1220                    method: "POST".to_string(),
1221                    path: "/v1/test".to_string(),
1222                    body: r#"{"prompt":"test"}"#.to_string(),
1223                    model: "test-model".to_string(),
1224                    api_key: "test-key".to_string(),
1225                }],
1226            )
1227            .await
1228            .expect("Failed to create file");
1229
1230        let batch = manager
1231            .create_batch(crate::batch::BatchInput {
1232                file_id,
1233                endpoint: "/v1/chat/completions".to_string(),
1234                completion_window: "24h".to_string(),
1235                metadata: None,
1236                created_by: None,
1237            })
1238            .await
1239            .expect("Failed to create batch");
1240
1241        let requests = manager
1242            .get_batch_requests(batch.id)
1243            .await
1244            .expect("Failed to get batch requests");
1245        assert_eq!(requests.len(), 1);
1246        let request_id = requests[0].id();
1247
1248        // Start the daemon
1249        let shutdown_token = tokio_util::sync::CancellationToken::new();
1250        manager
1251            .clone()
1252            .run(shutdown_token.clone())
1253            .expect("Failed to start daemon");
1254
1255        // Poll for completion (with timeout)
1256        let start = tokio::time::Instant::now();
1257        let timeout = Duration::from_secs(5);
1258        let mut completed = false;
1259
1260        while start.elapsed() < timeout {
1261            let results = manager
1262                .get_requests(vec![request_id])
1263                .await
1264                .expect("Failed to get request");
1265
1266            if let Some(Ok(any_request)) = results.first()
1267                && let crate::AnyRequest::Completed(req) = any_request
1268            {
1269                // Verify the request eventually completed successfully
1270                assert_eq!(req.state.response_status, 200);
1271                assert_eq!(
1272                    req.state.response_body,
1273                    r#"{"result":"success after retries"}"#
1274                );
1275                completed = true;
1276                break;
1277            }
1278
1279            tokio::time::sleep(Duration::from_millis(50)).await;
1280        }
1281
1282        // Stop the daemon
1283        shutdown_token.cancel();
1284
1285        assert!(completed, "Request should have completed after retries");
1286
1287        // Verify the request was attempted 3 times (2 failures + 1 success)
1288        assert_eq!(
1289            http_client.call_count(),
1290            3,
1291            "Expected 3 HTTP calls (2 failed attempts + 1 success)"
1292        );
1293    }
1294
1295    #[sqlx::test]
1296    async fn test_daemon_dynamically_updates_concurrency_limits(pool: sqlx::PgPool) {
1297        // Setup: Create HTTP client with triggered responses
1298        let http_client = Arc::new(MockHttpClient::new());
1299
1300        // Add 10 triggered responses
1301        let mut triggers = vec![];
1302        for i in 1..=10 {
1303            let trigger = http_client.add_response_with_trigger(
1304                "POST /v1/test",
1305                Ok(HttpResponse {
1306                    status: 200,
1307                    body: format!(r#"{{"result":"{}"}}"#, i),
1308                }),
1309            );
1310            triggers.push(trigger);
1311        }
1312
1313        // Setup: Start with concurrency limit of 2 for "gpt-4"
1314        let model_concurrency_limits = Arc::new(dashmap::DashMap::new());
1315        model_concurrency_limits.insert("gpt-4".to_string(), 2);
1316
1317        let config = DaemonConfig {
1318            claim_batch_size: 10,
1319            claim_interval_ms: 10,
1320            default_model_concurrency: 10,
1321            model_concurrency_limits: model_concurrency_limits.clone(),
1322            model_escalations: Arc::new(dashmap::DashMap::new()),
1323            max_retries: Some(3),
1324            stop_before_deadline_ms: None,
1325            backoff_ms: 100,
1326            backoff_factor: 2,
1327            max_backoff_ms: 1000,
1328            timeout_ms: 5000,
1329            status_log_interval_ms: None,
1330            heartbeat_interval_ms: 10000,
1331            should_retry: Arc::new(default_should_retry),
1332            claim_timeout_ms: 60000,
1333            processing_timeout_ms: 600000,
1334            unclaim_batch_size: 100,
1335            batch_metadata_fields: vec![],
1336            cancellation_poll_interval_ms: 100, // Fast polling for tests
1337        };
1338
1339        let manager = Arc::new(
1340            PostgresRequestManager::with_client(
1341                TestDbPools::new(pool.clone()).await.unwrap(),
1342                http_client.clone(),
1343            )
1344            .with_config(config),
1345        );
1346
1347        // Setup: Create a file with 10 requests, all using "gpt-4"
1348        let templates: Vec<_> = (1..=10)
1349            .map(|i| crate::RequestTemplateInput {
1350                custom_id: None,
1351                endpoint: "https://api.example.com".to_string(),
1352                method: "POST".to_string(),
1353                path: "/v1/test".to_string(),
1354                body: format!(r#"{{"prompt":"test{}"}}"#, i),
1355                model: "gpt-4".to_string(),
1356                api_key: "test-key".to_string(),
1357            })
1358            .collect();
1359
1360        let file_id = manager
1361            .create_file(
1362                "test-file".to_string(),
1363                Some("Test dynamic limits".to_string()),
1364                templates,
1365            )
1366            .await
1367            .expect("Failed to create file");
1368
1369        let batch = manager
1370            .create_batch(crate::batch::BatchInput {
1371                file_id,
1372                endpoint: "/v1/chat/completions".to_string(),
1373                completion_window: "24h".to_string(),
1374                metadata: None,
1375                created_by: None,
1376            })
1377            .await
1378            .expect("Failed to create batch");
1379
1380        // Start the daemon
1381        let shutdown_token = tokio_util::sync::CancellationToken::new();
1382        manager
1383            .clone()
1384            .run(shutdown_token.clone())
1385            .expect("Failed to start daemon");
1386
1387        // Wait for exactly 2 requests to be in-flight (initial limit)
1388        let start = tokio::time::Instant::now();
1389        let timeout = Duration::from_secs(2);
1390        let mut reached_initial_limit = false;
1391
1392        while start.elapsed() < timeout {
1393            let in_flight = http_client.in_flight_count();
1394            if in_flight == 2 {
1395                reached_initial_limit = true;
1396                break;
1397            }
1398            tokio::time::sleep(Duration::from_millis(10)).await;
1399        }
1400
1401        assert!(
1402            reached_initial_limit,
1403            "Expected exactly 2 requests in-flight with initial limit"
1404        );
1405
1406        // Increase the limit to 5
1407        model_concurrency_limits.insert("gpt-4".to_string(), 5);
1408
1409        // Wait a bit for the daemon to pick up the new limit
1410        tokio::time::sleep(Duration::from_millis(100)).await;
1411
1412        // Complete one request to free up a permit and trigger daemon to check limits
1413        triggers.remove(0).send(()).unwrap();
1414        tokio::time::sleep(Duration::from_millis(50)).await;
1415
1416        // Now we should see up to 5 requests in flight
1417        let start = tokio::time::Instant::now();
1418        let timeout = Duration::from_secs(2);
1419        let mut reached_new_limit = false;
1420
1421        while start.elapsed() < timeout {
1422            let in_flight = http_client.in_flight_count();
1423            if in_flight >= 4 {
1424                // Should see at least 4-5 in flight with new limit
1425                reached_new_limit = true;
1426                break;
1427            }
1428            tokio::time::sleep(Duration::from_millis(10)).await;
1429        }
1430
1431        assert!(
1432            reached_new_limit,
1433            "Expected more requests in-flight after limit increase, got {}",
1434            http_client.in_flight_count()
1435        );
1436
1437        // Now decrease the limit to 3
1438        model_concurrency_limits.insert("gpt-4".to_string(), 3);
1439
1440        // Complete remaining requests
1441        for trigger in triggers {
1442            trigger.send(()).unwrap();
1443        }
1444
1445        // Wait for all requests to complete
1446        let start = tokio::time::Instant::now();
1447        let timeout = Duration::from_secs(5);
1448        let mut all_completed = false;
1449
1450        while start.elapsed() < timeout {
1451            let status = manager
1452                .get_batch_status(batch.id)
1453                .await
1454                .expect("Failed to get batch status");
1455
1456            if status.completed_requests == 10 {
1457                all_completed = true;
1458                break;
1459            }
1460            tokio::time::sleep(Duration::from_millis(50)).await;
1461        }
1462
1463        // Stop the daemon
1464        shutdown_token.cancel();
1465
1466        assert!(all_completed, "All 10 requests should have completed");
1467        assert_eq!(http_client.call_count(), 10);
1468    }
1469
1470    #[sqlx::test]
1471    async fn test_deadline_aware_retry_stops_before_deadline(pool: sqlx::PgPool) {
1472        // Test that retries stop when approaching the deadline
1473        let http_client = Arc::new(MockHttpClient::new());
1474
1475        // All requests will fail
1476        for _ in 0..20 {
1477            http_client.add_response(
1478                "POST /v1/test",
1479                Ok(HttpResponse {
1480                    status: 500,
1481                    body: r#"{"error":"server error"}"#.to_string(),
1482                }),
1483            );
1484        }
1485
1486        // Use deadline-aware retry with a short completion window and short buffer
1487        let config = DaemonConfig {
1488            claim_batch_size: 10,
1489            claim_interval_ms: 10,
1490            default_model_concurrency: 10,
1491            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1492            model_escalations: Arc::new(dashmap::DashMap::new()),
1493            max_retries: Some(10_000),
1494            stop_before_deadline_ms: Some(500), // 500ms buffer before deadline
1495            backoff_ms: 50,
1496            backoff_factor: 2,
1497            max_backoff_ms: 200,
1498            timeout_ms: 5000,
1499            status_log_interval_ms: None,
1500            heartbeat_interval_ms: 10000,
1501            should_retry: Arc::new(default_should_retry),
1502            claim_timeout_ms: 60000,
1503            processing_timeout_ms: 600000,
1504            unclaim_batch_size: 100,
1505            batch_metadata_fields: vec![],
1506            cancellation_poll_interval_ms: 100,
1507        };
1508
1509        let manager = Arc::new(
1510            PostgresRequestManager::with_client(
1511                TestDbPools::new(pool.clone()).await.unwrap(),
1512                http_client.clone(),
1513            )
1514            .with_config(config),
1515        );
1516
1517        // Create a batch with a very short completion window (2 seconds)
1518        let file_id = manager
1519            .create_file(
1520                "test-file".to_string(),
1521                Some("Test deadline cutoff".to_string()),
1522                vec![crate::RequestTemplateInput {
1523                    custom_id: None,
1524                    endpoint: "https://api.example.com".to_string(),
1525                    method: "POST".to_string(),
1526                    path: "/v1/test".to_string(),
1527                    body: r#"{"prompt":"test"}"#.to_string(),
1528                    model: "test-model".to_string(),
1529                    api_key: "test-key".to_string(),
1530                }],
1531            )
1532            .await
1533            .expect("Failed to create file");
1534
1535        let batch = manager
1536            .create_batch(crate::batch::BatchInput {
1537                file_id,
1538                endpoint: "/v1/chat/completions".to_string(),
1539                completion_window: "2s".to_string(), // Very short window
1540                metadata: None,
1541                created_by: None,
1542            })
1543            .await
1544            .expect("Failed to create batch");
1545
1546        let requests = manager
1547            .get_batch_requests(batch.id)
1548            .await
1549            .expect("Failed to get batch requests");
1550        let request_id = requests[0].id();
1551
1552        // Start the daemon
1553        let shutdown_token = tokio_util::sync::CancellationToken::new();
1554        manager
1555            .clone()
1556            .run(shutdown_token.clone())
1557            .expect("Failed to start daemon");
1558
1559        // Wait for the deadline to pass
1560        tokio::time::sleep(Duration::from_secs(3)).await;
1561
1562        // Check the request state
1563        let results = manager
1564            .get_requests(vec![request_id])
1565            .await
1566            .expect("Failed to get request");
1567
1568        shutdown_token.cancel();
1569
1570        if let Some(Ok(crate::AnyRequest::Failed(failed))) = results.first() {
1571            // Calculate expected retry attempts:
1572            // - Completion window: 2000ms
1573            // - Buffer: 500ms
1574            // - Effective deadline: 1500ms
1575            // - Backoff sequence: 50ms, 100ms, 200ms, 200ms, 200ms, 200ms, 200ms
1576            // - Timeline:
1577            //   - Initial attempt: t=0ms (attempt 0)
1578            //   - Retry 1: t=50ms (attempt 1)
1579            //   - Retry 2: t=150ms (attempt 2)
1580            //   - Retry 3: t=350ms (attempt 3)
1581            //   - Retry 4: t=550ms (attempt 4)
1582            //   - Retry 5: t=750ms (attempt 5)
1583            //   - Retry 6: t=950ms (attempt 6)
1584            //   - Retry 7: t=1150ms (attempt 7)
1585            //   - Retry 8: t=1350ms (attempt 8)
1586            //   - Next would be t=1550ms - EXCEEDS 1500ms deadline
1587            // Expected: 8 retry attempts (9 total including initial)
1588
1589            let retry_count = failed.state.retry_attempt;
1590            let call_count = http_client.call_count();
1591
1592            // 1. Verify we stopped before too many retries (deadline constraint)
1593            // Allow 4-9 attempts to account for timing variations in test execution,
1594            // parallel test execution overhead, and query overhead from batch metadata fields
1595            assert!(
1596                (4..=9).contains(&retry_count),
1597                "Expected 4-9 retry attempts based on deadline and backoff calculation, got {}",
1598                retry_count
1599            );
1600
1601            // 2. Verify HTTP call count matches retry attempts (1 initial + N retries)
1602            assert_eq!(
1603                call_count,
1604                (retry_count + 1) as usize,
1605                "Expected call count to match retry attempts + 1 initial attempt, got {} calls for {} retry attempts",
1606                call_count,
1607                retry_count
1608            );
1609
1610            // 3. Verify the request actually has error details from the last attempt
1611            assert!(
1612                !failed.state.reason.to_error_message().is_empty(),
1613                "Expected failed request to have failure reason"
1614            );
1615        } else {
1616            panic!(
1617                "Expected request to be in Failed state, got {:?}",
1618                results.first()
1619            );
1620        }
1621    }
1622
1623    #[sqlx::test]
1624    async fn test_retry_stops_at_deadline_when_no_limits_set(pool: sqlx::PgPool) {
1625        // Test that when neither max_retries nor stop_before_deadline_ms is set,
1626        // retries stop exactly at the deadline (no buffer)
1627        let http_client = Arc::new(MockHttpClient::new());
1628
1629        // All requests will fail
1630        for _ in 0..20 {
1631            http_client.add_response(
1632                "POST /v1/test",
1633                Ok(HttpResponse {
1634                    status: 500,
1635                    body: r#"{"error":"server error"}"#.to_string(),
1636                }),
1637            );
1638        }
1639
1640        // No max_retries, no stop_before_deadline_ms
1641        let config = DaemonConfig {
1642            claim_batch_size: 10,
1643            claim_interval_ms: 10,
1644            default_model_concurrency: 10,
1645            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1646            model_escalations: Arc::new(dashmap::DashMap::new()),
1647            max_retries: None,             // No retry limit
1648            stop_before_deadline_ms: None, // No buffer - should retry until deadline
1649            backoff_ms: 50,
1650            backoff_factor: 2,
1651            max_backoff_ms: 200,
1652            timeout_ms: 5000,
1653            status_log_interval_ms: None,
1654            heartbeat_interval_ms: 10000,
1655            should_retry: Arc::new(default_should_retry),
1656            claim_timeout_ms: 60000,
1657            processing_timeout_ms: 600000,
1658            unclaim_batch_size: 100,
1659            batch_metadata_fields: vec![],
1660            cancellation_poll_interval_ms: 100,
1661        };
1662
1663        let manager = Arc::new(
1664            PostgresRequestManager::with_client(
1665                TestDbPools::new(pool.clone()).await.unwrap(),
1666                http_client.clone(),
1667            )
1668            .with_config(config),
1669        );
1670
1671        // Create a batch with a 2 second completion window
1672        let file_id = manager
1673            .create_file(
1674                "test-file".to_string(),
1675                Some("Test no limits retry".to_string()),
1676                vec![crate::RequestTemplateInput {
1677                    custom_id: None,
1678                    endpoint: "https://api.example.com".to_string(),
1679                    method: "POST".to_string(),
1680                    path: "/v1/test".to_string(),
1681                    body: r#"{"prompt":"test"}"#.to_string(),
1682                    model: "test-model".to_string(),
1683                    api_key: "test-key".to_string(),
1684                }],
1685            )
1686            .await
1687            .expect("Failed to create file");
1688
1689        let batch = manager
1690            .create_batch(crate::batch::BatchInput {
1691                file_id,
1692                endpoint: "/v1/chat/completions".to_string(),
1693                completion_window: "2s".to_string(),
1694                metadata: None,
1695                created_by: None,
1696            })
1697            .await
1698            .expect("Failed to create batch");
1699
1700        let requests = manager
1701            .get_batch_requests(batch.id)
1702            .await
1703            .expect("Failed to get batch requests");
1704        let request_id = requests[0].id();
1705
1706        // Start the daemon
1707        let shutdown_token = tokio_util::sync::CancellationToken::new();
1708        manager
1709            .clone()
1710            .run(shutdown_token.clone())
1711            .expect("Failed to start daemon");
1712
1713        // Wait for the deadline to pass
1714        tokio::time::sleep(Duration::from_secs(3)).await;
1715
1716        // Check the request state
1717        let results = manager
1718            .get_requests(vec![request_id])
1719            .await
1720            .expect("Failed to get request");
1721
1722        shutdown_token.cancel();
1723
1724        if let Some(Ok(crate::AnyRequest::Failed(failed))) = results.first() {
1725            // Calculate expected retry attempts with NO buffer:
1726            // - Completion window: 2000ms
1727            // - Buffer: 0ms (none set)
1728            // - Effective deadline: 2000ms
1729            // - Backoff sequence: 50ms, 100ms, 200ms, 200ms, 200ms...
1730            // - Timeline:
1731            //   - Initial attempt: t=0ms (attempt 0)
1732            //   - Retry 1: t=50ms (attempt 1)
1733            //   - Retry 2: t=150ms (attempt 2)
1734            //   - Retry 3: t=350ms (attempt 3)
1735            //   - Retry 4: t=550ms (attempt 4)
1736            //   - Retry 5: t=750ms (attempt 5)
1737            //   - Retry 6: t=950ms (attempt 6)
1738            //   - Retry 7: t=1150ms (attempt 7)
1739            //   - Retry 8: t=1350ms (attempt 8)
1740            //   - Retry 9: t=1550ms (attempt 9)
1741            //   - Retry 10: t=1750ms (attempt 10)
1742            //   - Retry 11: t=1950ms (attempt 11)
1743            //   - Next would be t=2150ms - EXCEEDS 2000ms deadline
1744            // Expected: ~11 retry attempts (12 total including initial)
1745            // In reality, we will see <11 due to DB calls and CPU overhead in making requests
1746
1747            let retry_count = failed.state.retry_attempt;
1748            let call_count = http_client.call_count();
1749
1750            // 1. Verify we retried more than the buffered case (which stopped at ~8)
1751            //    but still stopped before too many attempts
1752            // Allow 6-12 attempts to account for timing variations with CI slower CI CPUs,
1753            // parallel test execution overhead, and query overhead from batch metadata fields
1754            assert!(
1755                (6..12).contains(&retry_count),
1756                "Expected 6-12 retry attempts (should retry until deadline with no buffer), got {}",
1757                retry_count
1758            );
1759
1760            // 2. Verify HTTP call count matches retry attempts (1 initial + N retries)
1761            assert_eq!(
1762                call_count,
1763                (retry_count + 1) as usize,
1764                "Expected call count to match retry attempts + 1 initial attempt, got {} calls for {} retry attempts",
1765                call_count,
1766                retry_count
1767            );
1768
1769            // 3. Verify the request has error details from the last attempt
1770            assert!(
1771                !failed.state.reason.to_error_message().is_empty(),
1772                "Expected failed request to have failure reason"
1773            );
1774        } else {
1775            panic!(
1776                "Expected request to be in Failed state, got {:?}",
1777                results.first()
1778            );
1779        }
1780    }
1781
1782    #[sqlx::test]
1783    #[test_log::test]
1784    async fn test_batch_metadata_headers_passed_through(pool: sqlx::PgPool) {
1785        let http_client = crate::http::MockHttpClient::new();
1786        http_client.add_response(
1787            "POST /v1/chat/completions",
1788            Ok(crate::http::HttpResponse {
1789                status: 200,
1790                body: r#"{"id":"chatcmpl-123","choices":[{"message":{"content":"test"}}]}"#
1791                    .to_string(),
1792            }),
1793        );
1794
1795        let config = DaemonConfig {
1796            claim_batch_size: 10,
1797            claim_interval_ms: 10,
1798            default_model_concurrency: 10,
1799            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1800            model_escalations: Arc::new(dashmap::DashMap::new()),
1801            max_retries: Some(3),
1802            stop_before_deadline_ms: None,
1803            backoff_ms: 100,
1804            backoff_factor: 2,
1805            max_backoff_ms: 1000,
1806            timeout_ms: 5000,
1807            status_log_interval_ms: None,
1808            heartbeat_interval_ms: 10000,
1809            should_retry: Arc::new(default_should_retry),
1810            claim_timeout_ms: 60000,
1811            processing_timeout_ms: 600000,
1812            unclaim_batch_size: 100,
1813            batch_metadata_fields: vec![
1814                "id".to_string(),
1815                "endpoint".to_string(),
1816                "created_at".to_string(),
1817                "completion_window".to_string(),
1818            ],
1819            cancellation_poll_interval_ms: 100,
1820        };
1821
1822        let manager = Arc::new(
1823            PostgresRequestManager::with_client(
1824                TestDbPools::new(pool.clone()).await.unwrap(),
1825                Arc::new(http_client.clone()),
1826            )
1827            .with_config(config),
1828        );
1829
1830        // Create a batch
1831        let file_id = manager
1832            .create_file(
1833                "test-file".to_string(),
1834                Some("Test batch metadata".to_string()),
1835                vec![crate::RequestTemplateInput {
1836                    custom_id: None,
1837                    endpoint: "https://api.example.com".to_string(),
1838                    method: "POST".to_string(),
1839                    path: "/v1/chat/completions".to_string(),
1840                    body: r#"{"prompt":"test"}"#.to_string(),
1841                    model: "test-model".to_string(),
1842                    api_key: "test-key".to_string(),
1843                }],
1844            )
1845            .await
1846            .expect("Failed to create file");
1847
1848        let batch = manager
1849            .create_batch(crate::batch::BatchInput {
1850                file_id,
1851                endpoint: "/v1/chat/completions".to_string(),
1852                completion_window: "24h".to_string(),
1853                metadata: None,
1854                created_by: Some("test-user".to_string()),
1855            })
1856            .await
1857            .expect("Failed to create batch");
1858
1859        let requests = manager
1860            .get_batch_requests(batch.id)
1861            .await
1862            .expect("Failed to get batch requests");
1863        let request_id = requests[0].id();
1864
1865        // Start the daemon
1866        let shutdown_token = tokio_util::sync::CancellationToken::new();
1867        manager
1868            .clone()
1869            .run(shutdown_token.clone())
1870            .expect("Failed to start daemon");
1871
1872        // Wait for request to be processed
1873        tokio::time::sleep(Duration::from_millis(500)).await;
1874
1875        shutdown_token.cancel();
1876
1877        // Wait a bit for shutdown
1878        tokio::time::sleep(Duration::from_millis(200)).await;
1879
1880        // Verify the request was completed
1881        let results = manager
1882            .get_requests(vec![request_id])
1883            .await
1884            .expect("Failed to get request");
1885
1886        assert_eq!(results.len(), 1);
1887        assert!(
1888            matches!(results[0], Ok(crate::AnyRequest::Completed(_))),
1889            "Expected request to be completed"
1890        );
1891
1892        // Verify batch metadata was passed to HTTP client
1893        let calls = http_client.get_calls();
1894        assert_eq!(calls.len(), 1, "Expected exactly one HTTP call");
1895
1896        let call = &calls[0];
1897        assert_eq!(
1898            call.batch_metadata.len(),
1899            4,
1900            "Expected 4 batch metadata fields"
1901        );
1902
1903        // Verify each configured field was passed through
1904        assert!(
1905            call.batch_metadata.contains_key("id"),
1906            "Expected batch id in metadata"
1907        );
1908        assert!(
1909            call.batch_metadata.contains_key("endpoint"),
1910            "Expected batch endpoint in metadata"
1911        );
1912        assert!(
1913            call.batch_metadata.contains_key("created_at"),
1914            "Expected batch created_at in metadata"
1915        );
1916        assert!(
1917            call.batch_metadata.contains_key("completion_window"),
1918            "Expected batch completion_window in metadata"
1919        );
1920
1921        // Verify values are correct
1922        assert_eq!(
1923            call.batch_metadata.get("endpoint"),
1924            Some(&"/v1/chat/completions".to_string()),
1925            "Batch endpoint should match"
1926        );
1927        assert_eq!(
1928            call.batch_metadata.get("completion_window"),
1929            Some(&"24h".to_string()),
1930            "Completion window should match"
1931        );
1932    }
1933
1934    #[sqlx::test]
1935    #[test_log::test]
1936    async fn test_batch_metadata_extracts_fields_from_json_metadata(pool: sqlx::PgPool) {
1937        let http_client = crate::http::MockHttpClient::new();
1938        http_client.add_response(
1939            "POST /v1/chat/completions",
1940            Ok(crate::http::HttpResponse {
1941                status: 200,
1942                body: r#"{"id":"chatcmpl-123","choices":[{"message":{"content":"test"}}]}"#
1943                    .to_string(),
1944            }),
1945        );
1946
1947        // Configure batch_metadata_fields to include "request_source" which is stored
1948        // inside the metadata JSON, not as a direct column
1949        let config = DaemonConfig {
1950            claim_batch_size: 10,
1951            claim_interval_ms: 10,
1952            default_model_concurrency: 10,
1953            model_concurrency_limits: Arc::new(dashmap::DashMap::new()),
1954            model_escalations: Arc::new(dashmap::DashMap::new()),
1955            max_retries: Some(3),
1956            stop_before_deadline_ms: None,
1957            backoff_ms: 100,
1958            backoff_factor: 2,
1959            max_backoff_ms: 1000,
1960            timeout_ms: 5000,
1961            status_log_interval_ms: None,
1962            heartbeat_interval_ms: 10000,
1963            should_retry: Arc::new(default_should_retry),
1964            claim_timeout_ms: 60000,
1965            processing_timeout_ms: 600000,
1966            unclaim_batch_size: 100,
1967            batch_metadata_fields: vec![
1968                "id".to_string(),
1969                "endpoint".to_string(),
1970                "completion_window".to_string(),
1971                "request_source".to_string(), // This comes from metadata JSON
1972            ],
1973            cancellation_poll_interval_ms: 100,
1974        };
1975
1976        let manager = Arc::new(
1977            PostgresRequestManager::with_client(
1978                TestDbPools::new(pool.clone()).await.unwrap(),
1979                Arc::new(http_client.clone()),
1980            )
1981            .with_config(config),
1982        );
1983
1984        // Create a batch with metadata containing request_source
1985        let file_id = manager
1986            .create_file(
1987                "test-file".to_string(),
1988                Some("Test metadata JSON extraction".to_string()),
1989                vec![crate::RequestTemplateInput {
1990                    custom_id: None,
1991                    endpoint: "https://api.example.com".to_string(),
1992                    method: "POST".to_string(),
1993                    path: "/v1/chat/completions".to_string(),
1994                    body: r#"{"prompt":"test"}"#.to_string(),
1995                    model: "test-model".to_string(),
1996                    api_key: "test-key".to_string(),
1997                }],
1998            )
1999            .await
2000            .expect("Failed to create file");
2001
2002        let batch = manager
2003            .create_batch(crate::batch::BatchInput {
2004                file_id,
2005                endpoint: "/v1/chat/completions".to_string(),
2006                completion_window: "24h".to_string(),
2007                metadata: Some(serde_json::json!({
2008                    "request_source": "api",
2009                    "created_by": "user-123"
2010                })),
2011                created_by: Some("test-user".to_string()),
2012            })
2013            .await
2014            .expect("Failed to create batch");
2015
2016        let requests = manager
2017            .get_batch_requests(batch.id)
2018            .await
2019            .expect("Failed to get batch requests");
2020        let request_id = requests[0].id();
2021
2022        // Start the daemon
2023        let shutdown_token = tokio_util::sync::CancellationToken::new();
2024        manager
2025            .clone()
2026            .run(shutdown_token.clone())
2027            .expect("Failed to start daemon");
2028
2029        // Wait for request to be processed
2030        tokio::time::sleep(Duration::from_millis(500)).await;
2031
2032        shutdown_token.cancel();
2033
2034        // Wait a bit for shutdown
2035        tokio::time::sleep(Duration::from_millis(200)).await;
2036
2037        // Verify the request was completed
2038        let results = manager
2039            .get_requests(vec![request_id])
2040            .await
2041            .expect("Failed to get request");
2042
2043        assert_eq!(results.len(), 1);
2044        assert!(
2045            matches!(results[0], Ok(crate::AnyRequest::Completed(_))),
2046            "Expected request to be completed"
2047        );
2048
2049        // Verify batch metadata was passed to HTTP client
2050        let calls = http_client.get_calls();
2051        assert_eq!(calls.len(), 1, "Expected exactly one HTTP call");
2052
2053        let call = &calls[0];
2054        assert_eq!(
2055            call.batch_metadata.len(),
2056            4,
2057            "Expected 4 batch metadata fields (id, endpoint, completion_window, request_source)"
2058        );
2059
2060        // Verify direct column fields
2061        assert!(
2062            call.batch_metadata.contains_key("id"),
2063            "Expected batch id in metadata"
2064        );
2065        assert_eq!(
2066            call.batch_metadata.get("endpoint"),
2067            Some(&"/v1/chat/completions".to_string()),
2068            "Batch endpoint should match"
2069        );
2070
2071        // Verify request_source was extracted from metadata JSON
2072        assert_eq!(
2073            call.batch_metadata.get("request_source"),
2074            Some(&"api".to_string()),
2075            "request_source should be extracted from metadata JSON"
2076        );
2077    }
2078}