use super::super::types::*;
use crate::storage::database::Database;
use crate::utils::error::gateway_error::{GatewayError, Result};
use chrono::Utc;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info};
/// Coordinates batch creation, lookup, cancellation, and result retrieval.
///
/// Persistent state is accessed through [`Database`]; the two maps are
/// in-memory caches consulted before the database. Every field is an `Arc`,
/// so clones of the processor share the same database handle and caches.
pub struct BatchProcessor {
    /// Persistent store used to create batches, update their status, and
    /// load batch requests and results.
    pub(super) database: Arc<Database>,
    /// In-memory batch responses keyed by batch id; checked before the database.
    pub(super) active_batches: Arc<RwLock<HashMap<String, BatchResponse>>>,
    /// In-memory batch results keyed by batch id; checked before the database.
    pub(super) results_storage: Arc<RwLock<HashMap<String, Vec<BatchResult>>>>,
}
impl BatchProcessor {
pub fn new(database: Arc<Database>) -> Self {
Self {
database,
active_batches: Arc::new(RwLock::new(HashMap::new())),
results_storage: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn create_batch(&self, request: BatchRequest) -> Result<BatchResponse> {
info!("Creating batch: {}", request.batch_id);
self.validate_batch_request(&request).await?;
let batch_response = BatchResponse {
id: request.batch_id.clone(),
object: "batch".to_string(),
endpoint: self.get_endpoint_for_batch_type(&request.batch_type),
status: BatchStatus::Validating,
created_at: Utc::now(),
completed_at: None,
expires_at: Some(
Utc::now()
+ chrono::Duration::hours(request.completion_window.unwrap_or(24) as i64),
),
input_file_id: None,
output_file_id: None,
error_file_id: None,
request_counts: BatchRequestCounts {
total: request.requests.len() as i32,
completed: 0,
failed: 0,
},
metadata: Some(
serde_json::to_value(request.metadata.clone()).unwrap_or(serde_json::Value::Null),
),
completion_window: format!("{}h", request.completion_window.unwrap_or(24)),
in_progress_at: None,
finalizing_at: None,
failed_at: None,
expired_at: None,
cancelling_at: None,
cancelled_at: None,
};
self.database.create_batch(&request).await?;
{
let mut active = self.active_batches.write().await;
active.insert(request.batch_id.clone(), batch_response.clone());
}
let processor = self.clone();
let batch_id = request.batch_id.clone();
tokio::spawn(async move {
if let Err(e) = processor.process_batch(batch_id).await {
error!("Batch processing failed: {}", e);
}
});
Ok(batch_response)
}
pub async fn get_batch(&self, batch_id: &str) -> Result<Option<BatchResponse>> {
{
let active = self.active_batches.read().await;
if let Some(batch) = active.get(batch_id) {
return Ok(Some(batch.clone()));
}
}
if let Some(batch_request) = self.database.get_batch_request(batch_id).await? {
let now = chrono::Utc::now();
let batch_response = BatchResponse {
id: batch_request.batch_id.clone(),
object: "batch".to_string(),
endpoint: "/v1/chat/completions".to_string(),
input_file_id: Some(batch_request.batch_id.clone()),
completion_window: "24h".to_string(),
status: BatchStatus::Completed,
output_file_id: Some(format!("{}_output", batch_request.batch_id)),
error_file_id: None,
created_at: now,
in_progress_at: Some(now),
expires_at: Some(now + chrono::Duration::try_days(1).unwrap_or_default()),
finalizing_at: None,
completed_at: Some(now),
failed_at: None,
expired_at: None,
cancelling_at: None,
cancelled_at: None,
request_counts: BatchRequestCounts {
total: batch_request.requests.len() as i32,
completed: batch_request.requests.len() as i32,
failed: 0,
},
metadata: Some(
serde_json::to_value(batch_request.metadata).unwrap_or(serde_json::Value::Null),
),
};
return Ok(Some(batch_response));
}
Ok(None)
}
pub async fn cancel_batch(&self, batch_id: &str) -> Result<BatchResponse> {
info!("Cancelling batch: {}", batch_id);
let mut batch = self
.get_batch(batch_id)
.await?
.ok_or_else(|| GatewayError::NotFound("Batch not found".to_string()))?;
match batch.status {
BatchStatus::Validating | BatchStatus::InProgress => {
batch.status = BatchStatus::Cancelling;
{
let mut active = self.active_batches.write().await;
active.insert(batch_id.to_string(), batch.clone());
}
self.database
.update_batch_status(batch_id, &format!("{:?}", batch.status))
.await?;
Ok(batch)
}
_ => Err(GatewayError::BadRequest(
"Batch cannot be cancelled in current status".to_string(),
)),
}
}
pub async fn list_batches(
&self,
_user_id: &str,
limit: Option<u32>,
after: Option<&str>,
) -> Result<Vec<BatchResponse>> {
let records = self
.database
.list_batches(Some(limit.unwrap_or(20) as i32), after)
.await?;
let responses = records
.into_iter()
.map(|record| BatchResponse {
id: record.id,
object: record.object,
endpoint: record.endpoint,
status: record.status,
created_at: record.created_at,
completed_at: record.completed_at,
expires_at: record.expires_at,
input_file_id: record.input_file_id,
output_file_id: record.output_file_id,
error_file_id: record.error_file_id,
request_counts: record.request_counts,
metadata: record.metadata,
completion_window: record.completion_window,
in_progress_at: record.in_progress_at,
finalizing_at: record.finalizing_at,
failed_at: record.failed_at,
expired_at: record.expired_at,
cancelling_at: record.cancelling_at,
cancelled_at: record.cancelled_at,
})
.collect();
Ok(responses)
}
pub async fn get_batch_results(&self, batch_id: &str) -> Result<Vec<BatchResult>> {
{
let results = self.results_storage.read().await;
if let Some(batch_results) = results.get(batch_id) {
return Ok(batch_results.clone());
}
}
match self.database.get_batch_results(batch_id).await? {
Some(json_results) => {
let results: Vec<BatchResult> = json_results
.into_iter()
.filter_map(|v| serde_json::from_value(v).ok())
.collect();
Ok(results)
}
None => Ok(Vec::new()),
}
}
}
impl Clone for BatchProcessor {
fn clone(&self) -> Self {
Self {
database: self.database.clone(),
active_batches: self.active_batches.clone(),
results_storage: self.results_storage.clone(),
}
}
}