use crate::core::models::openai::{ChatCompletionRequest, EmbeddingRequest};
use crate::storage::database::Database;
use crate::utils::error::{GatewayError, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info};
use uuid::Uuid;
/// A user-submitted batch job: a set of API requests to execute asynchronously.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Caller-supplied identifier for the batch.
    pub batch_id: String,
    /// Owner of the batch.
    pub user_id: String,
    /// Kind of API calls this batch contains (chat completions, embeddings, ...).
    pub batch_type: BatchType,
    /// The individual requests to execute (validated to be 1..=50,000 items).
    pub requests: Vec<BatchItem>,
    /// Arbitrary key/value metadata, echoed back on the batch object as JSON.
    pub metadata: HashMap<String, String>,
    /// Expiry window in hours; defaults to 24 when `None`.
    pub completion_window: Option<u32>,
    /// Optional URL to notify about batch events.
    // NOTE(review): webhook_url is never read in this file — confirm it is consumed elsewhere.
    pub webhook_url: Option<String>,
}
/// The category of API call a batch performs; determines the target endpoint
/// and how each item's body is interpreted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BatchType {
    /// Items target `/v1/chat/completions`.
    ChatCompletion,
    /// Items target `/v1/embeddings`.
    Embedding,
    /// Items target `/v1/images/generations`.
    ImageGeneration,
    /// Items target the given custom endpoint; per-item URL validation is skipped.
    Custom(String),
}
/// One request inside a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchItem {
    /// Caller's correlation id for this item; must be 1–64 characters.
    pub custom_id: String,
    /// HTTP method; only "POST" is accepted by validation.
    pub method: String,
    /// Target URL; must match the batch type's endpoint for known types.
    pub url: String,
    /// Raw JSON request body, deserialized per batch type during processing.
    pub body: serde_json::Value,
}
/// Persisted (database) representation of a batch, mirroring the fields of an
/// OpenAI-style batch object, including one timestamp per lifecycle transition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRecord {
    pub id: String,
    /// Object type tag (e.g. "batch").
    pub object: String,
    /// API endpoint the batch targets (e.g. "/v1/chat/completions").
    pub endpoint: String,
    pub input_file_id: Option<String>,
    /// Human-readable expiry window, e.g. "24h".
    pub completion_window: String,
    pub status: BatchStatus,
    pub output_file_id: Option<String>,
    pub error_file_id: Option<String>,
    pub created_at: DateTime<Utc>,
    // Each *_at field below is `None` until the corresponding lifecycle
    // transition has occurred.
    pub in_progress_at: Option<DateTime<Utc>>,
    pub expires_at: Option<DateTime<Utc>>,
    pub finalizing_at: Option<DateTime<Utc>>,
    pub completed_at: Option<DateTime<Utc>>,
    pub failed_at: Option<DateTime<Utc>>,
    pub expired_at: Option<DateTime<Utc>>,
    pub cancelling_at: Option<DateTime<Utc>>,
    pub cancelled_at: Option<DateTime<Utc>>,
    /// Running totals of processed items.
    pub request_counts: BatchRequestCounts,
    /// Caller metadata serialized to JSON (may be `Null`).
    pub metadata: Option<serde_json::Value>,
}
/// API-facing view of a batch. Constructed directly in `create_batch` and
/// mapped field-for-field from `BatchRecord` in `list_batches`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResponse {
    pub id: String,
    /// Object type tag (e.g. "batch").
    pub object: String,
    /// API endpoint the batch targets.
    pub endpoint: String,
    pub status: BatchStatus,
    pub created_at: DateTime<Utc>,
    pub completed_at: Option<DateTime<Utc>>,
    pub expires_at: Option<DateTime<Utc>>,
    pub input_file_id: Option<String>,
    pub output_file_id: Option<String>,
    pub error_file_id: Option<String>,
    /// Running totals of processed items.
    pub request_counts: BatchRequestCounts,
    /// Caller metadata serialized to JSON (may be `Null`).
    pub metadata: Option<serde_json::Value>,
    /// Human-readable expiry window, e.g. "24h".
    pub completion_window: String,
    // Each *_at field below is `None` until the corresponding lifecycle
    // transition has occurred.
    pub in_progress_at: Option<DateTime<Utc>>,
    pub finalizing_at: Option<DateTime<Utc>>,
    pub failed_at: Option<DateTime<Utc>>,
    pub expired_at: Option<DateTime<Utc>>,
    pub cancelling_at: Option<DateTime<Utc>>,
    pub cancelled_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BatchStatus {
Validating,
Failed,
InProgress,
Finalizing,
Completed,
Expired,
Cancelling,
Cancelled,
}
/// Running totals of how many items in a batch have been processed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequestCounts {
    /// Total number of items in the batch.
    pub total: i32,
    /// Items that produced a successful response.
    pub completed: i32,
    /// Items that produced an error.
    pub failed: i32,
}
/// Outcome of a single batch item: exactly one of `response` or `error` is
/// populated by `process_batch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResult {
    /// Echo of the item's `custom_id` for correlation.
    pub custom_id: String,
    /// Successful HTTP-style response, when processing succeeded.
    pub response: Option<BatchHttpResponse>,
    /// Error details, when processing failed.
    pub error: Option<BatchError>,
}
/// HTTP-shaped response payload for a successfully processed batch item.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchHttpResponse {
    /// HTTP status code (200 for the synthesized responses in this module).
    pub status_code: u16,
    /// Response headers (currently always empty in this module).
    pub headers: HashMap<String, String>,
    /// JSON response body.
    pub body: serde_json::Value,
}
/// Error details for a failed batch item.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchError {
    /// Machine-readable error code (e.g. "processing_error").
    pub code: String,
    /// Human-readable error message.
    pub message: String,
    /// Optional structured error details.
    pub details: Option<serde_json::Value>,
}
/// Orchestrates batch lifecycle: validation, background processing, status
/// tracking, and result storage. Cheap to clone (all fields are `Arc`s), so
/// clones share state across spawned tasks.
pub struct BatchProcessor {
    /// Persistent store for batch requests, statuses, and results.
    database: Arc<Database>,
    /// In-memory map of batch id -> latest `BatchResponse` for fast lookup.
    // NOTE(review): entries are never evicted — confirm memory growth is acceptable.
    active_batches: Arc<RwLock<HashMap<String, BatchResponse>>>,
    /// In-memory cache of batch id -> per-item results.
    results_storage: Arc<RwLock<HashMap<String, Vec<BatchResult>>>>,
}
impl BatchProcessor {
    /// Build a processor backed by `database`, with empty in-memory caches.
    pub fn new(database: Arc<Database>) -> Self {
        Self {
            database,
            active_batches: Arc::new(RwLock::new(HashMap::new())),
            results_storage: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Validate and register a new batch, persist it, and spawn a background
    /// task that processes its items.
    ///
    /// Returns the initial batch object in `Validating` status; processing
    /// continues asynchronously after this call returns. Background failures
    /// are only logged, never surfaced to the caller.
    pub async fn create_batch(&self, request: BatchRequest) -> Result<BatchResponse> {
        info!("Creating batch: {}", request.batch_id);
        self.validate_batch_request(&request).await?;
        let batch_response = BatchResponse {
            id: request.batch_id.clone(),
            object: "batch".to_string(),
            endpoint: self.get_endpoint_for_batch_type(&request.batch_type),
            status: BatchStatus::Validating,
            created_at: Utc::now(),
            completed_at: None,
            // Expiry defaults to 24 hours when no completion window is given.
            expires_at: Some(
                Utc::now()
                    + chrono::Duration::hours(request.completion_window.unwrap_or(24) as i64),
            ),
            input_file_id: None,
            output_file_id: None,
            error_file_id: None,
            request_counts: BatchRequestCounts {
                total: request.requests.len() as i32,
                completed: 0,
                failed: 0,
            },
            metadata: Some(
                serde_json::to_value(request.metadata.clone()).unwrap_or(serde_json::Value::Null),
            ),
            completion_window: format!("{}h", request.completion_window.unwrap_or(24)),
            in_progress_at: None,
            finalizing_at: None,
            failed_at: None,
            expired_at: None,
            cancelling_at: None,
            cancelled_at: None,
        };
        self.database.create_batch(&request).await?;
        {
            // Scope the write lock so it is dropped before spawning the task.
            let mut active = self.active_batches.write().await;
            active.insert(request.batch_id.clone(), batch_response.clone());
        }
        // Fire-and-forget: the clone shares Arc-backed state with `self`.
        let processor = self.clone();
        let batch_id = request.batch_id.clone();
        tokio::spawn(async move {
            if let Err(e) = processor.process_batch(batch_id).await {
                error!("Batch processing failed: {}", e);
            }
        });
        Ok(batch_response)
    }

    /// Look up a batch by id.
    ///
    /// NOTE(review): only consults the in-memory `active_batches` map — a
    /// batch present only in the database (e.g. after a process restart)
    /// returns `None`. Confirm whether a database fallback is intended.
    pub async fn get_batch(&self, batch_id: &str) -> Result<Option<BatchResponse>> {
        {
            let active = self.active_batches.read().await;
            if let Some(batch) = active.get(batch_id) {
                return Ok(Some(batch.clone()));
            }
        }
        Ok(None)
    }

    /// Request cancellation of a batch.
    ///
    /// Only batches in `Validating` or `InProgress` can be cancelled; the
    /// status moves to `Cancelling` here, and the background processing loop
    /// observes it via `is_batch_cancelled` and stops. Returns an
    /// `InvalidRequest` error for any other status, and `NotFound` when the
    /// batch is not in the in-memory map (see `get_batch`).
    pub async fn cancel_batch(&self, batch_id: &str) -> Result<BatchResponse> {
        info!("Cancelling batch: {}", batch_id);
        let mut batch = self
            .get_batch(batch_id)
            .await?
            .ok_or_else(|| GatewayError::NotFound("Batch not found".to_string()))?;
        match batch.status {
            BatchStatus::Validating | BatchStatus::InProgress => {
                batch.status = BatchStatus::Cancelling;
                {
                    let mut active = self.active_batches.write().await;
                    active.insert(batch_id.to_string(), batch.clone());
                }
                // Persist the Debug form of the status (e.g. "Cancelling").
                self.database
                    .update_batch_status(batch_id, &format!("{:?}", batch.status))
                    .await?;
                Ok(batch)
            }
            _ => Err(GatewayError::InvalidRequest(
                "Batch cannot be cancelled in current status".to_string(),
            )),
        }
    }

    /// List batches from the database, mapping records to API responses.
    ///
    /// `limit` defaults to 20; `after` is an opaque pagination cursor passed
    /// through to the database layer.
    ///
    /// NOTE(review): `_user_id` is ignored, so listing is not scoped to the
    /// requesting user — confirm whether per-user filtering is required.
    pub async fn list_batches(
        &self,
        _user_id: &str,
        limit: Option<u32>,
        after: Option<&str>,
    ) -> Result<Vec<BatchResponse>> {
        let records = self
            .database
            .list_batches(Some(limit.unwrap_or(20) as i32), after)
            .await?;
        let responses = records
            .into_iter()
            .map(|record| BatchResponse {
                id: record.id,
                object: record.object,
                endpoint: record.endpoint,
                status: record.status,
                created_at: record.created_at,
                completed_at: record.completed_at,
                expires_at: record.expires_at,
                input_file_id: record.input_file_id,
                output_file_id: record.output_file_id,
                error_file_id: record.error_file_id,
                request_counts: record.request_counts,
                metadata: record.metadata,
                completion_window: record.completion_window,
                in_progress_at: record.in_progress_at,
                finalizing_at: record.finalizing_at,
                failed_at: record.failed_at,
                expired_at: record.expired_at,
                cancelling_at: record.cancelling_at,
                cancelled_at: record.cancelled_at,
            })
            .collect();
        Ok(responses)
    }

    /// Fetch per-item results for a batch: in-memory cache first, then the
    /// database; an unknown batch id yields an empty list.
    ///
    /// NOTE(review): stored JSON values that fail to deserialize are silently
    /// dropped by `filter_map(.. .ok())` — confirm that losing malformed
    /// results is acceptable.
    pub async fn get_batch_results(&self, batch_id: &str) -> Result<Vec<BatchResult>> {
        {
            let results = self.results_storage.read().await;
            if let Some(batch_results) = results.get(batch_id) {
                return Ok(batch_results.clone());
            }
        }
        match self.database.get_batch_results(batch_id).await? {
            Some(json_results) => {
                let results: Vec<BatchResult> = json_results
                    .into_iter()
                    .filter_map(|v| serde_json::from_value(v).ok())
                    .collect();
                Ok(results)
            }
            None => Ok(Vec::new()),
        }
    }

    /// Reject batches that are empty or exceed 50,000 items, then validate
    /// each item individually.
    async fn validate_batch_request(&self, request: &BatchRequest) -> Result<()> {
        if request.requests.len() > 50000 {
            return Err(GatewayError::InvalidRequest(
                "Batch size exceeds maximum limit of 50,000 requests".to_string(),
            ));
        }
        if request.requests.is_empty() {
            return Err(GatewayError::InvalidRequest(
                "Batch must contain at least one request".to_string(),
            ));
        }
        for item in &request.requests {
            self.validate_batch_item(item, &request.batch_type).await?;
        }
        Ok(())
    }

    /// Validate one item: `custom_id` length (1–64), POST-only method, and a
    /// URL consistent with the batch type. `ImageGeneration` and `Custom`
    /// types skip the URL check.
    async fn validate_batch_item(&self, item: &BatchItem, batch_type: &BatchType) -> Result<()> {
        if item.custom_id.is_empty() || item.custom_id.len() > 64 {
            return Err(GatewayError::InvalidRequest(
                "custom_id must be 1-64 characters".to_string(),
            ));
        }
        if item.method != "POST" {
            return Err(GatewayError::InvalidRequest(
                "Only POST method is supported for batch requests".to_string(),
            ));
        }
        match batch_type {
            BatchType::ChatCompletion => {
                // Substring match, so any URL containing the path is accepted.
                if !item.url.contains("/chat/completions") {
                    return Err(GatewayError::InvalidRequest(
                        "URL must be /v1/chat/completions for chat completion batches".to_string(),
                    ));
                }
            }
            BatchType::Embedding => {
                if !item.url.contains("/embeddings") {
                    return Err(GatewayError::InvalidRequest(
                        "URL must be /v1/embeddings for embedding batches".to_string(),
                    ));
                }
            }
            _ => {}
        }
        Ok(())
    }

    /// Background worker for one batch: marks it in-progress, processes items
    /// sequentially (checking for cancellation before each), records per-item
    /// results, persists progress every 100 items, then stores results and
    /// finalizes the status.
    ///
    /// NOTE(review): `mark_batch_completed` runs even when the final status is
    /// `Cancelled`, so cancelled batches also get a `completed_at` timestamp —
    /// confirm this is intended.
    async fn process_batch(&self, batch_id: String) -> Result<()> {
        info!("Processing batch: {}", batch_id);
        self.update_batch_status(&batch_id, BatchStatus::InProgress)
            .await?;
        let batch_request = self
            .database
            .get_batch_request(&batch_id)
            .await?
            .ok_or_else(|| GatewayError::NotFound("Batch request not found".to_string()))?;
        let mut results = Vec::new();
        let mut completed = 0;
        let mut failed = 0;
        for item in &batch_request.requests {
            // Stop early if a cancellation request has been observed.
            if self.is_batch_cancelled(&batch_id).await? {
                break;
            }
            match self
                .process_batch_item(item, &batch_request.batch_type)
                .await
            {
                Ok(result) => {
                    results.push(result);
                    completed += 1;
                }
                Err(e) => {
                    // Item failures are recorded as error results; the batch
                    // as a whole continues processing.
                    let error_result = BatchResult {
                        custom_id: item.custom_id.clone(),
                        response: None,
                        error: Some(BatchError {
                            code: "processing_error".to_string(),
                            message: e.to_string(),
                            details: None,
                        }),
                    };
                    results.push(error_result);
                    failed += 1;
                }
            }
            // Persist progress every 100 processed items.
            if (completed + failed) % 100 == 0 {
                self.update_batch_progress(&batch_id, completed, failed)
                    .await?;
            }
        }
        {
            // Cache results in memory for fast retrieval.
            let mut storage = self.results_storage.write().await;
            storage.insert(batch_id.clone(), results.clone());
        }
        let json_results: Vec<serde_json::Value> = results
            .iter()
            .map(|r| serde_json::to_value(r).unwrap_or_default())
            .collect();
        self.database
            .store_batch_results(&batch_id, &json_results)
            .await?;
        // Re-check cancellation so a cancel arriving late still wins.
        let final_status = if self.is_batch_cancelled(&batch_id).await? {
            BatchStatus::Cancelled
        } else {
            BatchStatus::Completed
        };
        self.update_batch_status(&batch_id, final_status).await?;
        self.update_batch_progress(&batch_id, completed, failed)
            .await?;
        self.mark_batch_completed(&batch_id).await?;
        info!(
            "Batch processing completed: {} (completed: {}, failed: {})",
            batch_id, completed, failed
        );
        Ok(())
    }

    /// Process one item by batch type.
    ///
    /// NOTE(review): this returns hard-coded placeholder responses (fixed
    /// content and token counts) rather than calling an upstream provider —
    /// presumably a stub to be replaced with real routing.
    async fn process_batch_item(
        &self,
        item: &BatchItem,
        batch_type: &BatchType,
    ) -> Result<BatchResult> {
        debug!("Processing batch item: {}", item.custom_id);
        match batch_type {
            BatchType::ChatCompletion => {
                // Deserialize only to validate the body and echo the model.
                let request: ChatCompletionRequest = serde_json::from_value(item.body.clone())
                    .map_err(|e| {
                        GatewayError::InvalidRequest(format!("Invalid request body: {}", e))
                    })?;
                let response = BatchHttpResponse {
                    status_code: 200,
                    headers: HashMap::new(),
                    body: serde_json::json!({
                        "id": format!("chatcmpl-batch-{}", Uuid::new_v4()),
                        "object": "chat.completion",
                        "created": Utc::now().timestamp(),
                        "model": request.model,
                        "choices": [{
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": "This is a batch processed response."
                            },
                            "finish_reason": "stop"
                        }],
                        "usage": {
                            "prompt_tokens": 10,
                            "completion_tokens": 8,
                            "total_tokens": 18
                        }
                    }),
                };
                Ok(BatchResult {
                    custom_id: item.custom_id.clone(),
                    response: Some(response),
                    error: None,
                })
            }
            BatchType::Embedding => {
                let request: EmbeddingRequest =
                    serde_json::from_value(item.body.clone()).map_err(|e| {
                        GatewayError::InvalidRequest(format!("Invalid request body: {}", e))
                    })?;
                let response = BatchHttpResponse {
                    status_code: 200,
                    headers: HashMap::new(),
                    body: serde_json::json!({
                        "object": "list",
                        "data": [{
                            "object": "embedding",
                            "embedding": vec![0.1; 1536], "index": 0
                        }],
                        "model": request.model,
                        "usage": {
                            "prompt_tokens": 5,
                            "total_tokens": 5
                        }
                    }),
                };
                Ok(BatchResult {
                    custom_id: item.custom_id.clone(),
                    response: Some(response),
                    error: None,
                })
            }
            _ => Err(GatewayError::InvalidRequest(
                "Unsupported batch type".to_string(),
            )),
        }
    }

    /// Map a batch type to the API endpoint path used in batch objects.
    fn get_endpoint_for_batch_type(&self, batch_type: &BatchType) -> String {
        match batch_type {
            BatchType::ChatCompletion => "/v1/chat/completions".to_string(),
            BatchType::Embedding => "/v1/embeddings".to_string(),
            BatchType::ImageGeneration => "/v1/images/generations".to_string(),
            BatchType::Custom(endpoint) => endpoint.clone(),
        }
    }

    /// Update a batch's status in memory (if present) and in the database
    /// (persisted via its `Debug` representation).
    async fn update_batch_status(&self, batch_id: &str, status: BatchStatus) -> Result<()> {
        {
            let mut active = self.active_batches.write().await;
            if let Some(batch) = active.get_mut(batch_id) {
                batch.status = status.clone();
            }
        }
        self.database
            .update_batch_status(batch_id, &format!("{:?}", status))
            .await
    }

    /// Update completed/failed counters in memory (if present) and in the
    /// database.
    async fn update_batch_progress(
        &self,
        batch_id: &str,
        completed: u32,
        failed: u32,
    ) -> Result<()> {
        {
            let mut active = self.active_batches.write().await;
            if let Some(batch) = active.get_mut(batch_id) {
                batch.request_counts.completed = completed as i32;
                batch.request_counts.failed = failed as i32;
            }
        }
        self.database
            .update_batch_progress(batch_id, completed as i32, failed as i32)
            .await
    }

    /// Stamp `completed_at` in memory (if present) and mark the batch
    /// completed in the database.
    async fn mark_batch_completed(&self, batch_id: &str) -> Result<()> {
        let now = Utc::now();
        {
            let mut active = self.active_batches.write().await;
            if let Some(batch) = active.get_mut(batch_id) {
                batch.completed_at = Some(now);
            }
        }
        self.database.mark_batch_completed(batch_id).await
    }

    /// Whether a batch is in `Cancelling` or `Cancelled` state.
    /// Unknown batch ids (not in the in-memory map) report `false`.
    async fn is_batch_cancelled(&self, batch_id: &str) -> Result<bool> {
        let active = self.active_batches.read().await;
        if let Some(batch) = active.get(batch_id) {
            Ok(matches!(
                batch.status,
                BatchStatus::Cancelling | BatchStatus::Cancelled
            ))
        } else {
            Ok(false)
        }
    }
}
impl Clone for BatchProcessor {
fn clone(&self) -> Self {
Self {
database: self.database.clone(),
active_batches: self.active_batches.clone(),
results_storage: self.results_storage.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // TODO(review): the original body was empty and silently passed.
    // Constructing a batch requires a live `Database`; wire up a test
    // double before asserting anything here.
    #[tokio::test]
    async fn test_batch_creation() {
    }

    /// `BatchStatus` variants compare by identity.
    #[test]
    fn test_batch_status_transitions() {
        assert_eq!(BatchStatus::Validating, BatchStatus::Validating);
        assert_ne!(BatchStatus::Validating, BatchStatus::InProgress);
    }

    /// The cancellation classification used by `is_batch_cancelled`:
    /// only `Cancelling` and `Cancelled` count as cancelled.
    #[test]
    fn test_cancellation_statuses() {
        let cancelled = |s: &BatchStatus| {
            matches!(s, BatchStatus::Cancelling | BatchStatus::Cancelled)
        };
        assert!(cancelled(&BatchStatus::Cancelling));
        assert!(cancelled(&BatchStatus::Cancelled));
        assert!(!cancelled(&BatchStatus::InProgress));
        assert!(!cancelled(&BatchStatus::Completed));
    }
}