litellm-rs 0.4.16

A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features
Documentation
//! Vertex AI Batch Processing Module

use crate::ProviderError;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::core::types::responses::FinishReason;
use crate::core::types::{
    chat::ChatRequest, message::MessageContent, message::MessageRole, responses::ChatResponse,
};

/// Batch job for processing multiple requests
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchJob {
    pub id: String,
    pub status: BatchJobStatus,
    pub created_at: i64,
    pub updated_at: Option<i64>,
    pub completed_at: Option<i64>,
    pub input_config: BatchInputConfig,
    pub output_config: BatchOutputConfig,
    pub model: String,
    pub generation_config: Option<GenerationConfig>,
    pub error: Option<BatchError>,
    pub statistics: Option<BatchStatistics>,
}

/// Batch job status
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BatchJobStatus {
    #[serde(rename = "JOB_STATE_PENDING")]
    Pending,
    #[serde(rename = "JOB_STATE_RUNNING")]
    Running,
    #[serde(rename = "JOB_STATE_SUCCEEDED")]
    Succeeded,
    #[serde(rename = "JOB_STATE_FAILED")]
    Failed,
    #[serde(rename = "JOB_STATE_CANCELLED")]
    Cancelled,
    #[serde(rename = "JOB_STATE_PARTIALLY_SUCCEEDED")]
    PartiallySucceeded,
}

/// Batch input configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchInputConfig {
    pub gcs_source: Option<GcsSource>,
    pub bigquery_source: Option<BigQuerySource>,
    pub instances_format: String, // "jsonl", "bigquery"
}

/// Batch output configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchOutputConfig {
    pub gcs_destination: Option<GcsDestination>,
    pub bigquery_destination: Option<BigQueryDestination>,
    pub predictions_format: String, // "jsonl", "bigquery"
}

/// Google Cloud Storage source
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsSource {
    pub uris: Vec<String>, // gs://bucket/path/to/file.jsonl
}

/// BigQuery source
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BigQuerySource {
    pub input_uri: String, // bq://project.dataset.table
}

/// Google Cloud Storage destination
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsDestination {
    pub output_uri_prefix: String, // gs://bucket/path/to/output/
}

/// BigQuery destination
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BigQueryDestination {
    pub output_uri: String, // bq://project.dataset.table
}

/// Generation configuration for batch
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<i32>,
    pub max_output_tokens: Option<i32>,
    pub stop_sequences: Option<Vec<String>>,
    pub response_mime_type: Option<String>,
}

/// Batch error information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchError {
    pub code: i32,
    pub message: String,
    pub details: Option<Vec<BatchErrorDetail>>,
}

/// Batch error detail
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchErrorDetail {
    pub error_type: String,
    pub error_message: String,
}

/// Batch statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStatistics {
    pub input_count: i64,
    pub successful_count: i64,
    pub failed_count: i64,
    pub total_tokens: i64,
    pub input_tokens: i64,
    pub output_tokens: i64,
}

/// Batch request for creating a job
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateBatchJobRequest {
    pub display_name: Option<String>,
    pub model: String,
    pub input_config: BatchInputConfig,
    pub output_config: BatchOutputConfig,
    pub generation_config: Option<GenerationConfig>,
}

/// Batch handler for managing batch jobs
pub struct BatchHandler;

impl BatchHandler {
    /// Create a new batch handler
    pub fn new(_project_id: String, _location: String) -> Self {
        Self
    }

    /// Create a batch job
    pub async fn create_batch_job(
        &self,
        request: CreateBatchJobRequest,
    ) -> Result<BatchJob, ProviderError> {
        // NOTE: actual batch job creation via Vertex AI API not yet implemented
        Ok(BatchJob {
            id: uuid::Uuid::new_v4().to_string(),
            status: BatchJobStatus::Pending,
            created_at: chrono::Utc::now().timestamp(),
            updated_at: None,
            completed_at: None,
            input_config: request.input_config,
            output_config: request.output_config,
            model: request.model,
            generation_config: request.generation_config,
            error: None,
            statistics: None,
        })
    }

    /// Get batch job status
    pub async fn get_batch_job(&self, _job_id: &str) -> Result<BatchJob, ProviderError> {
        // NOTE: actual job retrieval not yet implemented
        Err(ProviderError::not_supported(
            "vertex_ai",
            "Batch job retrieval not yet implemented",
        ))
    }

    /// List batch jobs
    pub async fn list_batch_jobs(
        &self,
        _filter: Option<String>,
        _page_size: Option<i32>,
        _page_token: Option<String>,
    ) -> Result<Vec<BatchJob>, ProviderError> {
        // NOTE: actual job listing not yet implemented
        Ok(Vec::new())
    }

    /// Cancel a batch job
    pub async fn cancel_batch_job(&self, _job_id: &str) -> Result<(), ProviderError> {
        // NOTE: actual job cancellation not yet implemented
        Ok(())
    }

    /// Delete a batch job
    pub async fn delete_batch_job(&self, _job_id: &str) -> Result<(), ProviderError> {
        // NOTE: actual job deletion not yet implemented
        Ok(())
    }
}

/// Transform batch request to Vertex AI format
pub fn transform_batch_request(
    requests: Vec<ChatRequest>,
    model: &str,
) -> Result<Vec<Value>, ProviderError> {
    let mut batch_instances = Vec::new();

    for request in requests {
        // Transform each request to the appropriate format
        let instance = if model.contains("gemini") {
            transform_gemini_batch_instance(request)?
        } else {
            transform_default_batch_instance(request)?
        };

        batch_instances.push(instance);
    }

    Ok(batch_instances)
}

/// Transform single request to Gemini batch format
fn transform_gemini_batch_instance(request: ChatRequest) -> Result<Value, ProviderError> {
    use crate::core::providers::vertex_ai::parse_vertex_model;
    use crate::core::providers::vertex_ai::transformers::GeminiTransformer;

    let transformer = GeminiTransformer::new();
    let model = parse_vertex_model(&request.model);

    transformer.transform_chat_request(&request, &model)
}

/// Transform single request to default batch format
fn transform_default_batch_instance(request: ChatRequest) -> Result<Value, ProviderError> {
    Ok(serde_json::json!({
        "messages": request.messages.iter().map(|msg| {
            serde_json::json!({
                "role": msg.role.to_string().to_lowercase(),
                "content": msg.content.as_ref().map(|c| c.to_string()).unwrap_or_default()
            })
        }).collect::<Vec<_>>(),
        "parameters": {
            "temperature": request.temperature,
            "maxOutputTokens": request.max_tokens,
            "topP": request.top_p,
        }
    }))
}

/// Parse batch response from Vertex AI
pub fn parse_batch_response(
    response: Value,
    model: &str,
) -> Result<Vec<ChatResponse>, ProviderError> {
    let predictions = response["predictions"].as_array().ok_or_else(|| {
        ProviderError::response_parsing("vertex_ai", "Missing predictions in batch response")
    })?;

    let mut responses = Vec::new();

    for prediction in predictions {
        let chat_response = if model.contains("gemini") {
            parse_gemini_batch_response(prediction.clone(), model)?
        } else {
            parse_default_batch_response(prediction.clone(), model)?
        };

        responses.push(chat_response);
    }

    Ok(responses)
}

/// Parse Gemini batch response
fn parse_gemini_batch_response(
    response: Value,
    model: &str,
) -> Result<ChatResponse, ProviderError> {
    use crate::core::providers::vertex_ai::parse_vertex_model;
    use crate::core::providers::vertex_ai::transformers::GeminiTransformer;

    let transformer = GeminiTransformer::new();
    let model_obj = parse_vertex_model(model);

    transformer.transform_chat_response(response, &model_obj)
}

/// Parse default batch response
fn parse_default_batch_response(
    response: Value,
    model: &str,
) -> Result<ChatResponse, ProviderError> {
    use crate::core::types::chat::ChatMessage;
    use crate::core::types::responses::ChatChoice;

    let content = response["content"]
        .as_str()
        .or_else(|| response["text"].as_str())
        .or_else(|| response["output"].as_str())
        .map(|s| s.to_string());

    Ok(ChatResponse {
        id: uuid::Uuid::new_v4().to_string(),
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: model.to_string(),
        choices: vec![ChatChoice {
            index: 0,
            message: ChatMessage {
                role: MessageRole::Assistant,
                content: content.map(MessageContent::Text),
                thinking: None,
                name: None,
                tool_calls: None,
                function_call: None,
                tool_call_id: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        usage: None,
        system_fingerprint: None,
    })
}