use serde::{Deserialize, Serialize};
use snafu::Snafu;
use time::OffsetDateTime;
use crate::Model;
use crate::common::serde::*;
use crate::generation::{GenerateContentRequest, GenerationResponse};
/// One entry of a batch *request file*: a generation request paired with a
/// caller-assigned key.
///
/// NOTE(review): presumably each entry is written as one line of a JSONL file;
/// confirm against the upload code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequestFileItem {
/// The generation request to execute.
pub request: GenerateContentRequest,
/// Caller-assigned key. `BatchResponseFileItem` carries a field of the same
/// name, presumably so results can be matched back to requests.
/// Serialized through the `key_as_string` helper.
#[serde(with = "key_as_string")]
pub key: usize,
}
/// One entry of a batch *response file*: the outcome of a single request plus
/// the key that identifies which request it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResponseFileItem {
/// Outcome of the request. Because it is flattened, the enum's
/// `response` / `error` member appears at the same level as `key`.
#[serde(flatten)]
pub response: BatchGenerateContentResponseItem,
/// Key of the originating request, serialized through `key_as_string`.
#[serde(with = "key_as_string")]
pub key: usize,
}
impl From<BatchGenerateContentResponseItem> for Result<GenerationResponse, IndividualRequestError> {
fn from(response: BatchGenerateContentResponseItem) -> Self {
match response {
BatchGenerateContentResponseItem::Response(r) => Ok(r),
BatchGenerateContentResponseItem::Error(err) => Err(err),
}
}
}
/// Successful payload of a finished batch operation.
///
/// The enum is `untagged`: serde picks the variant by which key is present.
/// `inlinedResponses` selects `InlinedResponses`; `responsesFile` selects
/// `ResponsesFile`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum BatchOperationResponse {
/// Results returned inline. Note the nesting: the JSON shape is
/// `{"inlinedResponses": {"inlinedResponses": [...]}}`.
#[serde(rename_all = "camelCase")]
InlinedResponses { inlined_responses: InlinedResponses },
/// Results written to a file; the string is the file reference.
#[serde(rename_all = "camelCase")]
ResponsesFile { responses_file: String },
}
/// Wrapper object holding the list of inline per-request results.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlinedResponses {
/// One entry per request, serialized under the `inlinedResponses` key.
pub inlined_responses: Vec<InlinedBatchGenerationResponseItem>,
}
/// A single inline result: the request's metadata plus its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlinedBatchGenerationResponseItem {
/// Metadata echoed for the request; carries its `key`.
pub metadata: RequestMetadata,
/// Outcome of the request, flattened so `response` / `error` sit beside `metadata`.
#[serde(flatten)]
pub result: BatchGenerateContentResponseItem,
}
/// Outcome of one request inside a batch.
///
/// Externally tagged with camelCase variant names, so it (de)serializes as
/// either `{"response": ...}` or `{"error": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum BatchGenerateContentResponseItem {
/// The request succeeded and produced a generation.
Response(GenerationResponse),
/// The request failed individually; other requests in the batch are unaffected
/// by this variant itself.
Error(IndividualRequestError),
}
/// Error reported for a single request within a batch.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IndividualRequestError {
/// Numeric status code reported for the request.
pub code: i32,
/// Human-readable error message.
pub message: String,
/// Additional structured error details, kept as raw JSON.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<serde_json::Value>,
}
/// Response carrying a batch's resource name and its metadata.
///
/// NOTE(review): presumably returned when a batch is created; confirm at the call site.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchGenerateContentResponse {
/// Resource name of the batch operation.
pub name: String,
/// Metadata describing the batch.
pub metadata: BatchMetadata,
}
/// Descriptive metadata of a batch: model, display name, timestamps, progress, and state.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchMetadata {
/// Type annotation carried in the `@type` key of the JSON object.
#[serde(rename = "@type")]
pub type_annotation: String,
/// Model the batch runs against.
pub model: Model,
/// Human-readable name given to the batch.
pub display_name: String,
/// Creation time, as an RFC 3339 timestamp string.
#[serde(with = "time::serde::rfc3339")]
pub create_time: OffsetDateTime,
/// Last update time, as an RFC 3339 timestamp string.
#[serde(with = "time::serde::rfc3339")]
pub update_time: OffsetDateTime,
/// Request counters for the batch.
pub batch_stats: BatchStats,
/// Current lifecycle state.
pub state: BatchState,
/// Resource name of the batch. NOTE(review): presumably matches the
/// enclosing operation's `name`; confirm.
pub name: String,
}
/// A single inline request in a batch, with metadata identifying it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchRequestItem {
/// The generation request to execute.
pub request: GenerateContentRequest,
/// Metadata (the request `key`) used to match the result back to this request.
pub metadata: RequestMetadata,
}
/// Request body for submitting a batch; wraps the configuration under a `batch` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchGenerateContentRequest {
/// Configuration of the batch to submit.
pub batch: BatchConfig,
}
/// User-supplied configuration of a batch: a display name and where the requests come from.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchConfig {
/// Human-readable name for the batch.
pub display_name: String,
/// Source of the requests: inline list or uploaded file.
pub input_config: InputConfig,
}
/// Lifecycle state of a batch.
///
/// Serialized in SCREAMING_SNAKE_CASE, e.g. `BatchStateRunning` <-> `"BATCH_STATE_RUNNING"`.
/// NOTE(review): there is no `#[serde(other)]` catch-all, so an unrecognised
/// state string fails deserialization of the whole enclosing object.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
// Every variant repeats the enum name as a prefix so the derived wire names
// carry the `BATCH_STATE_` prefix; clippy's enum_variant_names would flag that.
#[allow(clippy::enum_variant_names)]
pub enum BatchState {
BatchStateUnspecified,
BatchStatePending,
BatchStateRunning,
BatchStateSucceeded,
BatchStateFailed,
BatchStateCancelled,
BatchStateExpired,
}
/// Request counters for a batch.
///
/// All counts are 64-bit integers sent as JSON strings (via `i64_as_string`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchStats {
/// Total number of requests in the batch (required).
#[serde(with = "i64_as_string")]
pub request_count: i64,
/// Requests not yet processed; `None` when the field is absent.
#[serde(default, with = "i64_as_string::optional")]
pub pending_request_count: Option<i64>,
/// Requests that finished (successfully or not); `None` when absent.
#[serde(default, with = "i64_as_string::optional")]
pub completed_request_count: Option<i64>,
/// Requests that failed; `None` when absent.
#[serde(default, with = "i64_as_string::optional")]
pub failed_request_count: Option<i64>,
/// Requests that succeeded; `None` when absent.
#[serde(default, with = "i64_as_string::optional")]
pub successful_request_count: Option<i64>,
}
/// A batch as a long-running operation: name, metadata, completion flag, and
/// (once finished) its result.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BatchOperation {
/// Resource name of the operation.
pub name: String,
/// Metadata describing the batch.
pub metadata: BatchMetadata,
/// Whether the operation has finished; `false` when the field is absent.
#[serde(default)]
pub done: bool,
/// Terminal outcome, read from a top-level `response` or `error` key.
/// `None` when neither key is present.
/// NOTE(review): serde's `flatten` on an `Option` also yields `None` when the
/// inner value fails to deserialize, so a malformed `response` may appear as
/// `None` rather than as an error. Confirm this is acceptable.
#[serde(flatten)]
pub result: Option<OperationResult>,
}
#[derive(Debug, Snafu, serde::Deserialize, serde::Serialize)]
pub struct OperationError {
pub code: i32,
pub message: String,
}
/// Terminal outcome of a batch operation.
///
/// Externally tagged with camelCase variant names, so it (de)serializes as
/// either `{"response": ...}` or `{"error": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum OperationResult {
/// The operation finished and produced results.
Response(BatchOperationResponse),
/// The operation as a whole failed.
Error(OperationError),
}
impl From<OperationResult> for Result<BatchOperationResponse, OperationError> {
fn from(operation: OperationResult) -> Self {
match operation {
OperationResult::Response(response) => Ok(response),
OperationResult::Error(error) => Err(error),
}
}
}
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListBatchesResponse {
pub operations: Vec<BatchOperation>,
pub next_page_token: Option<String>,
}
/// Where a batch's requests come from.
///
/// Externally tagged with camelCase variant names, so it (de)serializes as
/// either `{"requests": {...}}` or `{"fileName": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum InputConfig {
/// Requests supplied inline in the submission body.
Requests(RequestsContainer),
/// Name of a previously uploaded file containing the requests.
FileName(String),
}
impl InputConfig {
pub fn batch_size(&self) -> Option<usize> {
match self {
InputConfig::Requests(container) => Some(container.requests.len()),
InputConfig::FileName(_) => None,
}
}
}
/// Wrapper object holding an inline list of batch requests under a `requests` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestsContainer {
/// The inline requests, each with its own metadata key.
pub requests: Vec<BatchRequestItem>,
}
/// Per-request metadata carrying the caller-assigned key.
///
/// It appears on both inline requests (`BatchRequestItem`) and inline results
/// (`InlinedBatchGenerationResponseItem`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RequestMetadata {
/// Caller-assigned key, serialized through `key_as_string`.
#[serde(with = "key_as_string")]
pub key: usize,
}