use std::sync::Arc;
use std::time::{Duration, Instant};
use fusillade::{PersistCompletedRealtimeInput, PostgresRequestManager, ReqwestHttpClient, Storage};
use metrics::{counter, gauge, histogram};
use sqlx_pool_router::PoolProvider;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, warn};
use uuid::Uuid;
/// Capacity of the bounded mpsc channel created by `RequestsWriter::new`.
const CHANNEL_BUFFER_SIZE: usize = 10_000;

/// Number of retries after the first failed flush attempt of a batch, so a
/// batch gets at most `DEFAULT_MAX_RETRIES + 1` attempts before it is dropped.
const DEFAULT_MAX_RETRIES: u32 = 3;

/// Delay in milliseconds before the first retry; each later retry doubles it.
const DEFAULT_RETRY_BASE_DELAY_MS: u64 = 100;
/// A completed realtime request/response pair queued on the writer's channel
/// for asynchronous batch persistence by `RequestsWriter`.
// NOTE(review): `Debug` is derived, so `{:?}` formatting prints `api_key` and
// the full request/response bodies verbatim — avoid debug-logging this value,
// or consider a redacting `Debug` impl.
#[derive(Debug, Clone)]
pub struct RawCompletedRequest {
/// Identifier under which the record is persisted (and later looked up).
pub request_id: Uuid,
/// Response status code (e.g. `200`).
pub status_code: u16,
/// Raw response body, persisted as-is.
pub response_body: String,
/// Raw request body, persisted as-is.
pub request_body: String,
/// Model name associated with the request (e.g. `gpt-4` in the tests).
pub model: String,
/// Request path such as `/v1/responses`. Note that the writer stores this in
/// the persistence input's `path` field, not its `endpoint` field.
pub endpoint: String,
/// API key associated with the request; may be empty (the tests use `""`).
pub api_key: String,
/// Creator identifier recorded with the request (the tests use `user-test`).
pub created_by: String,
}
pub type RequestsWriterSender = mpsc::Sender<RawCompletedRequest>;
/// Background task that collects [`RawCompletedRequest`]s from a channel and
/// persists them in batches through fusillade's `PostgresRequestManager`.
///
/// Build one with `RequestsWriter::new` and drive it with `RequestsWriter::run`.
pub struct RequestsWriter<P>
where
    P: PoolProvider + Clone + Send + Sync + 'static,
{
    /// Persistence backend used for batch writes.
    request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>,
    /// Receiving half of the channel; paired with the sender returned by `new`.
    receiver: mpsc::Receiver<RawCompletedRequest>,
    /// Maximum number of records persisted per flush (clamped to at least 1 by `new`).
    batch_size: usize,
    /// Retries after the first failed attempt of a flush.
    max_retries: u32,
    /// Delay before the first retry; doubled on each subsequent retry.
    retry_base_delay: Duration,
}
impl<P: PoolProvider + Clone + Send + Sync + 'static> RequestsWriter<P> {
    /// Creates a writer together with the sender half of its bounded channel.
    ///
    /// The channel buffers up to `CHANNEL_BUFFER_SIZE` records. `batch_size`
    /// caps how many records are persisted per flush; `0` is clamped to `1`.
    /// Retry behaviour comes from `DEFAULT_MAX_RETRIES` and
    /// `DEFAULT_RETRY_BASE_DELAY_MS`.
    pub fn new(
        request_manager: Arc<PostgresRequestManager<P, ReqwestHttpClient>>,
        batch_size: usize,
    ) -> (Self, RequestsWriterSender) {
        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let writer = Self {
            request_manager,
            receiver,
            batch_size: batch_size.max(1),
            max_retries: DEFAULT_MAX_RETRIES,
            retry_base_delay: Duration::from_millis(DEFAULT_RETRY_BASE_DELAY_MS),
        };
        (writer, sender)
    }

    /// Runs the writer loop until `shutdown_token` is cancelled or every sender
    /// has been dropped, then returns.
    ///
    /// Each iteration waits for one record, pulls any further records already
    /// queued (up to `batch_size`) without waiting, and flushes them as a
    /// single batch. On cancellation the channel is closed to new sends and
    /// every record still queued is drained and flushed before returning.
    pub async fn run(mut self, shutdown_token: CancellationToken) {
        info!(
            batch_size = self.batch_size,
            max_retries = self.max_retries,
            "Responses writer started"
        );
        let mut buffer: Vec<RawCompletedRequest> = Vec::with_capacity(self.batch_size);
        loop {
            tokio::select! {
                // Poll cancellation first so a continuously busy channel cannot
                // postpone shutdown.
                biased;
                _ = shutdown_token.cancelled() => {
                    // Every loop iteration ends with a flush that empties the
                    // buffer, so this is expected to be zero.
                    let pre_buffered = buffer.len();
                    info!(pre_buffered, "Shutdown signal received, draining responses writer channel");
                    // Reject new sends while keeping already-queued records
                    // receivable; `recv` yields `None` once they are consumed.
                    self.receiver.close();
                    let mut drained = 0usize;
                    while let Some(record) = self.receiver.recv().await {
                        buffer.push(record);
                        drained += 1;
                        if buffer.len() >= self.batch_size {
                            self.flush_batch(&mut buffer).await;
                        }
                    }
                    // Flush the partial tail (a no-op when the buffer is empty).
                    self.flush_batch(&mut buffer).await;
                    // Counts records handed to `flush_batch`, including any in a
                    // batch dropped after exhausting its retries.
                    let total_flushed = pre_buffered + drained;
                    counter!("dwctl_requests_writer_shutdown_records_flushed").increment(total_flushed as u64);
                    info!(
                        pre_buffered,
                        drained,
                        total_flushed,
                        "Responses writer shutdown complete"
                    );
                    break;
                }
                maybe_record = self.receiver.recv() => {
                    match maybe_record {
                        Some(record) => buffer.push(record),
                        None => {
                            // All senders are gone: flush what we hold and stop.
                            info!("Responses writer channel closed, shutting down");
                            self.flush_batch(&mut buffer).await;
                            break;
                        }
                    }
                }
            }
            // Top up the batch with whatever is already queued, without waiting.
            while buffer.len() < self.batch_size {
                match self.receiver.try_recv() {
                    Ok(record) => buffer.push(record),
                    Err(_) => break,
                }
            }
            gauge!("dwctl_requests_writer_channel_depth").set(self.receiver.len() as f64);
            self.flush_batch(&mut buffer).await;
        }
    }

    /// Persists the contents of `buffer` as one batch, retrying with
    /// exponential backoff.
    ///
    /// Makes up to `max_retries + 1` attempts. If every attempt fails, this
    /// logs the last error, increments `dwctl_requests_writer_flush_errors_total`
    /// and drops the batch. The buffer is always empty on return, and its
    /// capacity is kept for reuse.
    async fn flush_batch(&self, buffer: &mut Vec<RawCompletedRequest>) {
        if buffer.is_empty() {
            return;
        }
        let batch_size = buffer.len();
        let span = info_span!("dwctl.flush_responses_batch", batch_size);
        async {
            let start = Instant::now();
            histogram!("dwctl_requests_writer_flush_size").record(batch_size as f64);
            // Move the records out instead of cloning them. The bodies can be
            // large, and the buffer is emptied whatever the outcome.
            let inputs: Vec<PersistCompletedRealtimeInput> =
                buffer.drain(..).map(Self::into_persist_input).collect();

            let mut last_error = None;
            for attempt in 0..=self.max_retries {
                match self.request_manager.persist_completed_realtime_batch(&inputs).await {
                    Ok(()) => {
                        if attempt > 0 {
                            debug!(attempt, batch_size, "Responses batch flush succeeded after retry");
                            counter!("dwctl_requests_writer_retries_total", "outcome" => "success").increment(1);
                        }
                        // Clear any error left over from an earlier failed attempt.
                        last_error = None;
                        break;
                    }
                    Err(e) => {
                        if attempt < self.max_retries {
                            let delay = self.backoff_delay(attempt);
                            warn!(
                                error = %e,
                                attempt = attempt + 1,
                                max_retries = self.max_retries,
                                delay_ms = delay.as_millis() as u64,
                                batch_size,
                                "Responses batch flush failed, retrying"
                            );
                            counter!("dwctl_requests_writer_retries_total", "outcome" => "retry").increment(1);
                            tokio::time::sleep(delay).await;
                        }
                        last_error = Some(e);
                    }
                }
            }

            if let Some(e) = last_error {
                error!(
                    error = %e,
                    batch_size,
                    attempts = self.max_retries + 1,
                    "Failed to flush responses batch after all retries, dropping batch"
                );
                counter!("dwctl_requests_writer_flush_errors_total").increment(1);
                return;
            }

            let duration = start.elapsed();
            histogram!("dwctl_requests_writer_flush_duration_seconds").record(duration.as_secs_f64());
            counter!("dwctl_requests_writer_records_total").increment(batch_size as u64);
            debug!(batch_size, duration_ms = duration.as_millis() as u64, "Flushed responses batch");
        }
        .instrument(span)
        .await;
    }

    /// Returns the delay to wait before retry number `attempt + 1`, which is
    /// `retry_base_delay * 2^attempt`.
    ///
    /// Saturates instead of overflowing, which would otherwise panic in debug
    /// builds or wrap in release builds, if `max_retries` is ever raised past 31.
    fn backoff_delay(&self, attempt: u32) -> Duration {
        self.retry_base_delay.saturating_mul(2u32.saturating_pow(attempt))
    }

    /// Converts a queued record into fusillade's persistence input, taking
    /// ownership so that the strings are moved rather than copied.
    fn into_persist_input(record: RawCompletedRequest) -> PersistCompletedRealtimeInput {
        PersistCompletedRealtimeInput {
            request_id: record.request_id,
            response_body: record.response_body,
            status_code: record.status_code,
            request_body: record.request_body,
            model: record.model,
            // NOTE(review): `endpoint` is left empty and the record's endpoint
            // (a path such as `/v1/responses`) goes into `path`. This preserves
            // the existing mapping; confirm against fusillade's field semantics.
            endpoint: String::new(),
            method: "POST".to_string(),
            path: record.endpoint,
            api_key: record.api_key,
            created_by: record.created_by,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use fusillade::{PostgresRequestManager, RequestId};
    use sqlx_pool_router::TestDbPools;
    use std::time::Duration;
    use tokio::time::timeout;
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;

    /// Concrete request-manager type used throughout these tests.
    type TestManager = PostgresRequestManager<TestDbPools, ReqwestHttpClient>;

    /// Builds a writer (batch size 8) on a freshly prepared fusillade pool.
    /// Also returns the sender and the shared manager, so tests can read
    /// persisted rows back.
    async fn build_writer(
        pool: sqlx::PgPool,
    ) -> (RequestsWriter<TestDbPools>, RequestsWriterSender, Arc<TestManager>) {
        let fusillade_pool = crate::test::utils::setup_fusillade_pool(&pool).await;
        let pools = TestDbPools::new(fusillade_pool).await.unwrap();
        let manager = Arc::new(PostgresRequestManager::with_client(
            pools,
            Arc::new(ReqwestHttpClient::default()),
        ));
        let (writer, sender) = RequestsWriter::new(Arc::clone(&manager), 8);
        (writer, sender, manager)
    }

    /// Spawns `writer.run` on the runtime and returns its cancellation token
    /// and join handle.
    fn spawn_writer(
        writer: RequestsWriter<TestDbPools>,
    ) -> (CancellationToken, tokio::task::JoinHandle<()>) {
        let shutdown = CancellationToken::new();
        let handle = tokio::spawn(writer.run(shutdown.clone()));
        (shutdown, handle)
    }

    /// Cancels the writer and requires its task to finish, without panicking,
    /// within 5 seconds.
    async fn stop_writer(shutdown: &CancellationToken, handle: tokio::task::JoinHandle<()>) {
        shutdown.cancel();
        timeout(Duration::from_secs(5), handle)
            .await
            .expect("writer should shut down within 5s")
            .expect("writer task should not panic");
    }

    /// A status-200 record with fixed request metadata and the given response body.
    fn completed_record(request_id: Uuid, response_body: String) -> RawCompletedRequest {
        RawCompletedRequest {
            request_id,
            status_code: 200,
            response_body,
            request_body: r#"{"input":"hi"}"#.to_string(),
            model: "gpt-4".to_string(),
            endpoint: "/v1/responses".to_string(),
            api_key: String::new(),
            created_by: "user-test".to_string(),
        }
    }

    /// Polls every 20ms until `request_id` reports status `completed`, and
    /// panics once `timeout_secs` have elapsed. A not-yet-existing request
    /// counts as "not completed yet"; any other lookup error panics.
    async fn wait_until_completed(
        manager: &TestManager,
        request_id: Uuid,
        timeout_secs: u64,
    ) -> fusillade::RequestDetail {
        let deadline = std::time::Instant::now() + Duration::from_secs(timeout_secs);
        loop {
            match Storage::get_request_detail(manager, RequestId(request_id)).await {
                Ok(detail) if detail.status == "completed" => return detail,
                Ok(_) | Err(fusillade::FusilladeError::RequestNotFound(_)) => {}
                Err(e) => panic!("get_request_detail failed: {e}"),
            }
            assert!(
                std::time::Instant::now() < deadline,
                "timed out waiting for request {request_id} to complete"
            );
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    #[sqlx::test]
    async fn test_writer_persists_completed_record(pool: sqlx::PgPool) {
        let (writer, sender, manager) = build_writer(pool).await;
        let (shutdown, handle) = spawn_writer(writer);

        let request_id = Uuid::new_v4();
        sender
            .send(completed_record(request_id, r#"{"output":"done"}"#.to_string()))
            .await
            .expect("send should succeed");

        let detail = wait_until_completed(&manager, request_id, 5).await;
        assert_eq!(detail.status, "completed");
        assert_eq!(detail.service_tier, Some("priority".to_string()));
        assert_eq!(detail.response_body, Some(r#"{"output":"done"}"#.to_string()));
        assert_eq!(detail.response_status, Some(200));

        stop_writer(&shutdown, handle).await;
    }

    #[sqlx::test]
    async fn test_writer_batches_multiple_records_in_one_flush(pool: sqlx::PgPool) {
        let (writer, sender, manager) = build_writer(pool).await;
        let (shutdown, handle) = spawn_writer(writer);

        let request_ids: Vec<Uuid> = std::iter::repeat_with(Uuid::new_v4).take(5).collect();
        for id in &request_ids {
            sender
                .send(completed_record(*id, format!(r#"{{"output":"{id}"}}"#)))
                .await
                .expect("send should succeed");
        }

        for id in &request_ids {
            let detail = wait_until_completed(&manager, *id, 5).await;
            assert_eq!(detail.status, "completed");
        }

        stop_writer(&shutdown, handle).await;
    }

    #[sqlx::test]
    async fn test_writer_drains_channel_on_shutdown(pool: sqlx::PgPool) {
        let (writer, sender, manager) = build_writer(pool).await;
        let (shutdown, handle) = spawn_writer(writer);

        let request_id = Uuid::new_v4();
        sender
            .send(completed_record(request_id, r#"{"output":"shutdown-test"}"#.to_string()))
            .await
            .expect("send should succeed");

        // Shut down right after sending; the drain path must still persist the record.
        stop_writer(&shutdown, handle).await;

        let detail = Storage::get_request_detail(&*manager, RequestId(request_id))
            .await
            .expect("request should exist after drain");
        assert_eq!(detail.status, "completed");
        assert_eq!(detail.response_body, Some(r#"{"output":"shutdown-test"}"#.to_string()));
        drop(sender);
    }
}