use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::core::types::{
    chat::ChatRequest,
    message::{MessageContent, MessageRole},
    responses::{ChatResponse, FinishReason},
};
use crate::ProviderError;

/// A Vertex AI batch prediction job, tracking its configuration, lifecycle
/// state, and (once finished) error details and usage statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJob {
    pub id: String,
    pub status: BatchJobStatus,
    pub created_at: i64,
    pub updated_at: Option<i64>,
    pub completed_at: Option<i64>,
    pub input_config: BatchInputConfig,
    pub output_config: BatchOutputConfig,
    pub model: String,
    pub generation_config: Option<GenerationConfig>,
    pub error: Option<BatchError>,
    pub statistics: Option<BatchStatistics>,
}

/// Lifecycle states of a batch job, mirroring the Vertex AI `JOB_STATE_*`
/// values used on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BatchJobStatus {
    #[serde(rename = "JOB_STATE_PENDING")]
    Pending,
    #[serde(rename = "JOB_STATE_RUNNING")]
    Running,
    #[serde(rename = "JOB_STATE_SUCCEEDED")]
    Succeeded,
    #[serde(rename = "JOB_STATE_FAILED")]
    Failed,
    #[serde(rename = "JOB_STATE_CANCELLED")]
    Cancelled,
    #[serde(rename = "JOB_STATE_PARTIALLY_SUCCEEDED")]
    PartiallySucceeded,
}

/// Where batch input instances are read from; exactly one of the GCS or
/// BigQuery sources should be set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchInputConfig {
    pub gcs_source: Option<GcsSource>,
    pub bigquery_source: Option<BigQuerySource>,
    pub instances_format: String,
}

/// Where batch predictions are written; exactly one of the GCS or BigQuery
/// destinations should be set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchOutputConfig {
    pub gcs_destination: Option<GcsDestination>,
    pub bigquery_destination: Option<BigQueryDestination>,
    pub predictions_format: String,
}

/// Cloud Storage input location(s), e.g. `gs://bucket/path/input.jsonl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsSource {
    pub uris: Vec<String>,
}

/// BigQuery input table, e.g. `bq://project.dataset.table`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BigQuerySource {
    pub input_uri: String,
}

/// Cloud Storage output prefix under which prediction files are written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsDestination {
    pub output_uri_prefix: String,
}

/// BigQuery output table for prediction results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BigQueryDestination {
    pub output_uri: String,
}

/// Sampling and output controls applied to every request in the batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    pub max_output_tokens: Option<i32>,
    pub stop_sequences: Option<Vec<String>>,
    pub response_mime_type: Option<String>,
}

/// Top-level error reported for a failed (or partially failed) batch job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchError {
    pub code: i32,
    pub message: String,
    pub details: Option<Vec<BatchErrorDetail>>,
}

/// A single per-instance error within a batch job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchErrorDetail {
    pub error_type: String,
    pub error_message: String,
}

/// Aggregate instance counts and token usage for a completed batch job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStatistics {
    pub input_count: i64,
    pub successful_count: i64,
    pub failed_count: i64,
    pub total_tokens: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
}

/// Parameters for creating a new batch job via
/// [`BatchHandler::create_batch_job`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateBatchJobRequest {
    pub display_name: Option<String>,
    pub model: String,
    pub input_config: BatchInputConfig,
    pub output_config: BatchOutputConfig,
    pub generation_config: Option<GenerationConfig>,
}
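
/// Handler for Vertex AI batch prediction jobs.
///
/// A minimal usage sketch; the project, model name, and GCS paths are
/// illustrative placeholders, and the example is `ignore`d because the
/// handler below is still a local stub rather than a real API client:
///
/// ```ignore
/// let handler = BatchHandler::new("my-project".into(), "us-central1".into());
/// let job = handler
///     .create_batch_job(CreateBatchJobRequest {
///         display_name: Some("nightly-eval".into()),
///         model: "gemini-1.5-pro".into(),
///         input_config: BatchInputConfig {
///             gcs_source: Some(GcsSource {
///                 uris: vec!["gs://my-bucket/input.jsonl".into()],
///             }),
///             bigquery_source: None,
///             instances_format: "jsonl".into(),
///         },
///         output_config: BatchOutputConfig {
///             gcs_destination: Some(GcsDestination {
///                 output_uri_prefix: "gs://my-bucket/output/".into(),
///             }),
///             bigquery_destination: None,
///             predictions_format: "jsonl".into(),
///         },
///         generation_config: None,
///     })
///     .await?;
/// assert_eq!(job.status, BatchJobStatus::Pending);
/// ```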
pub struct BatchHandler;

impl BatchHandler {
    /// The project and location are currently unused because the handler is
    /// a stub; they stay in the signature so call sites do not change once
    /// real API calls are wired in.
    pub fn new(_project_id: String, _location: String) -> Self {
        Self
    }

    /// Creates a batch job. Currently a local stub: it fabricates a pending
    /// job with a random ID instead of calling the Vertex AI API.
    pub async fn create_batch_job(
        &self,
        request: CreateBatchJobRequest,
    ) -> Result<BatchJob, ProviderError> {
        Ok(BatchJob {
            id: uuid::Uuid::new_v4().to_string(),
            status: BatchJobStatus::Pending,
            created_at: chrono::Utc::now().timestamp(),
            updated_at: None,
            completed_at: None,
            input_config: request.input_config,
            output_config: request.output_config,
            model: request.model,
            generation_config: request.generation_config,
            error: None,
            statistics: None,
        })
    }

    /// Fetches a batch job by ID. Not yet implemented.
    pub async fn get_batch_job(&self, _job_id: &str) -> Result<BatchJob, ProviderError> {
        Err(ProviderError::not_supported(
            "vertex_ai",
            "Batch job retrieval not yet implemented",
        ))
    }

    /// Lists batch jobs. Currently a stub that always returns an empty list.
    pub async fn list_batch_jobs(
        &self,
        _filter: Option<String>,
        _page_size: Option<i32>,
        _page_token: Option<String>,
    ) -> Result<Vec<BatchJob>, ProviderError> {
        Ok(Vec::new())
    }

    /// Cancels a batch job. Currently a no-op stub that reports success.
    pub async fn cancel_batch_job(&self, _job_id: &str) -> Result<(), ProviderError> {
        Ok(())
    }

    /// Deletes a batch job. Currently a no-op stub that reports success.
    pub async fn delete_batch_job(&self, _job_id: &str) -> Result<(), ProviderError> {
        Ok(())
    }
}
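
/// Converts a set of [`ChatRequest`]s into provider-specific batch instance
/// JSON, routing Gemini models through the shared Gemini transformer and
/// everything else through a generic messages/parameters shape.
///
/// A minimal sketch (the model name is a placeholder, and `ChatRequest`
/// construction is elided):
///
/// ```ignore
/// let instances = transform_batch_request(vec![request], "gemini-1.5-pro")?;
/// assert_eq!(instances.len(), 1);
/// ```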
pub fn transform_batch_request(
    requests: Vec<ChatRequest>,
    model: &str,
) -> Result<Vec<Value>, ProviderError> {
    let mut batch_instances = Vec::new();
    for request in requests {
        let instance = if model.contains("gemini") {
            transform_gemini_batch_instance(request)?
        } else {
            transform_default_batch_instance(request)?
        };
        batch_instances.push(instance);
    }
    Ok(batch_instances)
}

/// Delegates a single instance to the shared Gemini request transformer.
fn transform_gemini_batch_instance(request: ChatRequest) -> Result<Value, ProviderError> {
    use crate::core::providers::vertex_ai::parse_vertex_model;
    use crate::core::providers::vertex_ai::transformers::GeminiTransformer;

    let transformer = GeminiTransformer::new();
    let model = parse_vertex_model(&request.model);
    transformer.transform_chat_request(&request, &model)
}

/// Builds a generic batch instance with lowercase role names and camelCase
/// sampling parameters, i.e.
/// `{"messages": [{"role": "...", "content": "..."}], "parameters": {...}}`.
fn transform_default_batch_instance(request: ChatRequest) -> Result<Value, ProviderError> {
    Ok(serde_json::json!({
        "messages": request.messages.iter().map(|msg| {
            serde_json::json!({
                "role": msg.role.to_string().to_lowercase(),
                "content": msg.content.as_ref().map(|c| c.to_string()).unwrap_or_default()
            })
        }).collect::<Vec<_>>(),
        "parameters": {
            "temperature": request.temperature,
            "maxOutputTokens": request.max_tokens,
            "topP": request.top_p,
        }
    }))
}
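
/// Parses a batch response of the form `{"predictions": [...]}` into one
/// [`ChatResponse`] per prediction, routing by model family the same way as
/// [`transform_batch_request`].
///
/// A minimal sketch of the non-Gemini path (the payload is an illustrative
/// placeholder, not captured Vertex AI output):
///
/// ```ignore
/// let raw = serde_json::json!({ "predictions": [{ "content": "hello" }] });
/// let responses = parse_batch_response(raw, "some-text-model")?;
/// assert_eq!(responses.len(), 1);
/// ```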
pub fn parse_batch_response(
    response: Value,
    model: &str,
) -> Result<Vec<ChatResponse>, ProviderError> {
    let predictions = response["predictions"].as_array().ok_or_else(|| {
        ProviderError::response_parsing("vertex_ai", "Missing predictions in batch response")
    })?;

    let mut responses = Vec::new();
    for prediction in predictions {
        let chat_response = if model.contains("gemini") {
            parse_gemini_batch_response(prediction.clone(), model)?
        } else {
            parse_default_batch_response(prediction.clone(), model)?
        };
        responses.push(chat_response);
    }
    Ok(responses)
}

/// Delegates a single Gemini prediction to the shared response transformer.
fn parse_gemini_batch_response(
    response: Value,
    model: &str,
) -> Result<ChatResponse, ProviderError> {
    use crate::core::providers::vertex_ai::parse_vertex_model;
    use crate::core::providers::vertex_ai::transformers::GeminiTransformer;

    let transformer = GeminiTransformer::new();
    let model_obj = parse_vertex_model(model);
    transformer.transform_chat_response(response, &model_obj)
}

/// Parses a non-Gemini prediction, accepting the generated text under any of
/// the `content`, `text`, or `output` keys, and wraps it in a single-choice
/// [`ChatResponse`] with a synthetic ID and timestamp.
fn parse_default_batch_response(
    response: Value,
    model: &str,
) -> Result<ChatResponse, ProviderError> {
    use crate::core::types::chat::ChatMessage;
    use crate::core::types::responses::ChatChoice;

    let content = response["content"]
        .as_str()
        .or_else(|| response["text"].as_str())
        .or_else(|| response["output"].as_str())
        .map(|s| s.to_string());

    Ok(ChatResponse {
        id: uuid::Uuid::new_v4().to_string(),
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: model.to_string(),
        choices: vec![ChatChoice {
            index: 0,
            message: ChatMessage {
                role: MessageRole::Assistant,
                content: content.map(MessageContent::Text),
                thinking: None,
                name: None,
                tool_calls: None,
                function_call: None,
                tool_call_id: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
    })
}
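
// A few sanity tests for the pure helpers in this module. They only exercise
// code defined in this file (status deserialization and the generic response
// parser); the model name and prediction payloads are illustrative
// placeholders, not real Vertex AI output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_job_status_deserializes_from_wire_names() {
        let status: BatchJobStatus = serde_json::from_str("\"JOB_STATE_RUNNING\"").unwrap();
        assert_eq!(status, BatchJobStatus::Running);
    }

    #[test]
    fn default_batch_response_reads_content_or_text_keys() {
        // The `content` key is checked first.
        let resp =
            parse_default_batch_response(serde_json::json!({ "content": "hello" }), "some-model")
                .unwrap();
        assert_eq!(resp.model, "some-model");
        assert_eq!(resp.choices.len(), 1);
        assert!(matches!(
            &resp.choices[0].message.content,
            Some(MessageContent::Text(t)) if t == "hello"
        ));

        // Falls back to the `text` key when `content` is absent.
        let resp = parse_default_batch_response(serde_json::json!({ "text": "hi" }), "some-model")
            .unwrap();
        assert!(matches!(
            &resp.choices[0].message.content,
            Some(MessageContent::Text(t)) if t == "hi"
        ));
    }
}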