use super::super::types::*;
use super::core::BatchProcessor;
use crate::core::models::openai::{ChatCompletionRequest, EmbeddingRequest};
use crate::utils::error::gateway_error::{GatewayError, Result};
use chrono::Utc;
use std::collections::HashMap;
use tracing::{debug, info};
use uuid::Uuid;
impl BatchProcessor {
pub(super) async fn process_batch(&self, batch_id: String) -> Result<()> {
info!("Processing batch: {}", batch_id);
self.update_batch_status(&batch_id, BatchStatus::InProgress)
.await?;
let batch_request = self
.database
.get_batch_request(&batch_id)
.await?
.ok_or_else(|| GatewayError::NotFound("Batch request not found".to_string()))?;
let mut results = Vec::with_capacity(batch_request.requests.len());
let mut completed = 0;
let mut failed = 0;
for item in &batch_request.requests {
if self.is_batch_cancelled(&batch_id).await? {
break;
}
match self
.process_batch_item(item, &batch_request.batch_type)
.await
{
Ok(result) => {
results.push(result);
completed += 1;
}
Err(e) => {
let error_result = BatchResult {
custom_id: item.custom_id.clone(),
response: None,
error: Some(BatchError {
code: "processing_error".to_string(),
message: e.to_string(),
details: None,
}),
};
results.push(error_result);
failed += 1;
}
}
if (completed + failed) % 100 == 0 {
self.update_batch_progress(&batch_id, completed, failed)
.await?;
}
}
{
let mut storage = self.results_storage.write().await;
storage.insert(batch_id.clone(), results.clone());
}
let json_results: Vec<serde_json::Value> = results
.iter()
.map(|r| serde_json::to_value(r).unwrap_or_default())
.collect();
self.database
.store_batch_results(&batch_id, &json_results)
.await?;
let final_status = if self.is_batch_cancelled(&batch_id).await? {
BatchStatus::Cancelled
} else {
BatchStatus::Completed
};
self.update_batch_status(&batch_id, final_status).await?;
self.update_batch_progress(&batch_id, completed, failed)
.await?;
self.mark_batch_completed(&batch_id).await?;
info!(
"Batch processing completed: {} (completed: {}, failed: {})",
batch_id, completed, failed
);
Ok(())
}
pub(super) async fn process_batch_item(
&self,
item: &BatchItem,
batch_type: &BatchType,
) -> Result<BatchResult> {
debug!("Processing batch item: {}", item.custom_id);
match batch_type {
BatchType::ChatCompletion => {
let request: ChatCompletionRequest = serde_json::from_value(item.body.clone())
.map_err(|e| {
GatewayError::BadRequest(format!("Invalid request body: {}", e))
})?;
let response = BatchHttpResponse {
status_code: 200,
headers: HashMap::new(),
body: serde_json::json!({
"id": format!("chatcmpl-batch-{}", Uuid::new_v4()),
"object": "chat.completion",
"created": Utc::now().timestamp(),
"model": request.model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a batch processed response."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 8,
"total_tokens": 18
}
}),
};
Ok(BatchResult {
custom_id: item.custom_id.clone(),
response: Some(response),
error: None,
})
}
BatchType::Embedding => {
let request: EmbeddingRequest =
serde_json::from_value(item.body.clone()).map_err(|e| {
GatewayError::BadRequest(format!("Invalid request body: {}", e))
})?;
let response = BatchHttpResponse {
status_code: 200,
headers: HashMap::new(),
body: serde_json::json!({
"object": "list",
"data": [{
"object": "embedding",
"embedding": vec![0.1; 1536], "index": 0
}],
"model": request.model,
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}),
};
Ok(BatchResult {
custom_id: item.custom_id.clone(),
response: Some(response),
error: None,
})
}
_ => Err(GatewayError::BadRequest(
"Unsupported batch type".to_string(),
)),
}
}
pub(super) fn get_endpoint_for_batch_type(&self, batch_type: &BatchType) -> String {
match batch_type {
BatchType::ChatCompletion => "/v1/chat/completions".to_string(),
BatchType::Embedding => "/v1/embeddings".to_string(),
BatchType::ImageGeneration => "/v1/images/generations".to_string(),
BatchType::Custom(endpoint) => endpoint.clone(),
}
}
}