use crate::utils::error::gateway_error::{GatewayError, Result};
use sea_orm::*;
use tracing::{debug, warn};
use super::super::entities;
use super::types::SeaOrmDatabase;
impl SeaOrmDatabase {
pub async fn create_batch(&self, batch: &crate::core::batch::BatchRequest) -> Result<String> {
debug!("Creating batch: {}", batch.batch_id);
let active_model = entities::batch::ActiveModel {
id: Set(batch.batch_id.clone()),
object: Set("batch".to_string()),
endpoint: Set(match batch.batch_type {
crate::core::batch::BatchType::ChatCompletion => "/v1/chat/completions".to_string(),
crate::core::batch::BatchType::Embedding => "/v1/embeddings".to_string(),
crate::core::batch::BatchType::ImageGeneration => {
"/v1/images/generations".to_string()
}
crate::core::batch::BatchType::Custom(ref endpoint) => endpoint.clone(),
}),
input_file_id: Set(None),
completion_window: Set(format!("{}h", batch.completion_window.unwrap_or(24))),
status: Set("validating".to_string()),
output_file_id: Set(None),
error_file_id: Set(None),
created_at: Set(chrono::Utc::now().into()),
in_progress_at: Set(None),
finalizing_at: Set(None),
completed_at: Set(None),
failed_at: Set(None),
expired_at: Set(None),
cancelling_at: Set(None),
cancelled_at: Set(None),
request_counts_total: Set(Some(batch.requests.len() as i32)),
request_counts_completed: Set(Some(0)),
request_counts_failed: Set(Some(0)),
metadata: Set(Some(
serde_json::to_string(&batch.metadata).unwrap_or_default(),
)),
};
entities::Batch::insert(active_model)
.exec(&self.db)
.await
.map_err(GatewayError::from)?;
Ok(batch.batch_id.clone())
}
pub async fn update_batch_status(&self, batch_id: &str, status: &str) -> Result<()> {
debug!("Updating batch status: {} -> {}", batch_id, status);
let txn = self.db.begin().await.map_err(GatewayError::from)?;
let batch_model = entities::Batch::find_by_id(batch_id)
.one(&txn)
.await
.map_err(GatewayError::from)?
.ok_or_else(|| GatewayError::NotFound("Batch not found".to_string()))?;
let mut active_model: entities::batch::ActiveModel = batch_model.into();
active_model.status = Set(status.to_string());
let now = chrono::Utc::now().into();
match status {
"in_progress" => active_model.in_progress_at = Set(Some(now)),
"finalizing" => active_model.finalizing_at = Set(Some(now)),
"completed" => active_model.completed_at = Set(Some(now)),
"failed" => active_model.failed_at = Set(Some(now)),
"expired" => active_model.expired_at = Set(Some(now)),
"cancelling" => active_model.cancelling_at = Set(Some(now)),
"cancelled" => active_model.cancelled_at = Set(Some(now)),
_ => {}
}
active_model
.update(&txn)
.await
.map_err(GatewayError::from)?;
txn.commit().await.map_err(GatewayError::from)?;
Ok(())
}
pub async fn list_batches(
&self,
limit: Option<i32>,
after: Option<&str>,
) -> Result<Vec<crate::core::batch::BatchRecord>> {
debug!(
"Listing batches with limit: {:?}, after: {:?}",
limit, after
);
let mut query = entities::Batch::find();
if let Some(after_id) = after {
query = query.filter(entities::batch::Column::Id.gt(after_id));
}
if let Some(limit) = limit {
query = query.limit(limit as u64);
}
let batch_models = query
.order_by_desc(entities::batch::Column::CreatedAt)
.all(&self.db)
.await
.map_err(GatewayError::from)?;
let batch_records = batch_models
.into_iter()
.map(|model| {
let status = match model.status.as_str() {
"validating" => crate::core::batch::BatchStatus::Validating,
"failed" => crate::core::batch::BatchStatus::Failed,
"in_progress" => crate::core::batch::BatchStatus::InProgress,
"finalizing" => crate::core::batch::BatchStatus::Finalizing,
"completed" => crate::core::batch::BatchStatus::Completed,
"expired" => crate::core::batch::BatchStatus::Expired,
"cancelling" => crate::core::batch::BatchStatus::Cancelling,
"cancelled" => crate::core::batch::BatchStatus::Cancelled,
_ => crate::core::batch::BatchStatus::Failed,
};
crate::core::batch::BatchRecord {
id: model.id,
object: model.object,
endpoint: model.endpoint,
input_file_id: model.input_file_id,
completion_window: model.completion_window,
status,
output_file_id: model.output_file_id,
error_file_id: model.error_file_id,
created_at: model.created_at.with_timezone(&chrono::Utc),
in_progress_at: model.in_progress_at.map(|t| t.with_timezone(&chrono::Utc)),
expires_at: None, finalizing_at: model.finalizing_at.map(|t| t.with_timezone(&chrono::Utc)),
completed_at: model.completed_at.map(|t| t.with_timezone(&chrono::Utc)),
failed_at: model.failed_at.map(|t| t.with_timezone(&chrono::Utc)),
expired_at: model.expired_at.map(|t| t.with_timezone(&chrono::Utc)),
cancelling_at: model.cancelling_at.map(|t| t.with_timezone(&chrono::Utc)),
cancelled_at: model.cancelled_at.map(|t| t.with_timezone(&chrono::Utc)),
request_counts: crate::core::batch::BatchRequestCounts {
total: model.request_counts_total.unwrap_or(0),
completed: model.request_counts_completed.unwrap_or(0),
failed: model.request_counts_failed.unwrap_or(0),
},
metadata: model.metadata.and_then(|m| serde_json::from_str(&m).ok()),
}
})
.collect();
Ok(batch_records)
}
pub async fn get_batch_results(
&self,
_batch_id: &str,
) -> Result<Option<Vec<serde_json::Value>>> {
warn!("get_batch_results not implemented yet");
Ok(None)
}
pub async fn get_batch_request(
&self,
_batch_id: &str,
) -> Result<Option<crate::core::batch::BatchRequest>> {
warn!("get_batch_request not implemented yet");
Ok(None)
}
pub async fn store_batch_results(
&self,
_batch_id: &str,
_results: &[serde_json::Value],
) -> Result<()> {
warn!("store_batch_results not implemented yet");
Ok(())
}
pub async fn update_batch_progress(
&self,
batch_id: &str,
completed: i32,
failed: i32,
) -> Result<()> {
debug!(
"Updating batch progress: {} - completed: {}, failed: {}",
batch_id, completed, failed
);
let txn = self.db.begin().await.map_err(GatewayError::from)?;
let batch_model = entities::Batch::find_by_id(batch_id)
.one(&txn)
.await
.map_err(GatewayError::from)?
.ok_or_else(|| GatewayError::NotFound("Batch not found".to_string()))?;
let mut active_model: entities::batch::ActiveModel = batch_model.into();
active_model.request_counts_completed = Set(Some(completed));
active_model.request_counts_failed = Set(Some(failed));
active_model
.update(&txn)
.await
.map_err(GatewayError::from)?;
txn.commit().await.map_err(GatewayError::from)?;
Ok(())
}
pub async fn mark_batch_completed(&self, batch_id: &str) -> Result<()> {
debug!("Marking batch as completed: {}", batch_id);
let txn = self.db.begin().await.map_err(GatewayError::from)?;
let batch_model = entities::Batch::find_by_id(batch_id)
.one(&txn)
.await
.map_err(GatewayError::from)?
.ok_or_else(|| GatewayError::NotFound("Batch not found".to_string()))?;
let mut active_model: entities::batch::ActiveModel = batch_model.into();
active_model.status = Set("completed".to_string());
active_model.completed_at = Set(Some(chrono::Utc::now().into()));
active_model
.update(&txn)
.await
.map_err(GatewayError::from)?;
txn.commit().await.map_err(GatewayError::from)?;
Ok(())
}
}